Skip to content

speech_text

Author: Heli Qi · Affiliation: NAIST · Date: 2022.07

RandomSpkFeatDataset

Bases: SpeechTextDataset

Source code in speechain/dataset/speech_text.py
class RandomSpkFeatDataset(SpeechTextDataset):
    """SpeechTextDataset variant that attaches a randomly sampled external speaker
    embedding feature ('spk_feat') to every extracted data instance.

    Intended for evaluating open-set multi-speaker TTS: the reference speaker
    embedding does not come from the target utterance itself but is drawn from an
    external spk_feat file. Embeddings of several different speakers can be mixed
    together (summed and averaged) via `mixup_number`.
    """

    def dataset_init_fn(
        self,
        spk_feat: List[str] or str = None,
        use_aver_feat: bool = True,
        mixup_number: int = 1,
        **super_conf,
    ) -> None:
        """Initialize the speaker-embedding bookkeeping on top of the parent dataset.

        Args:
            spk_feat (List[str] or str): Path(s) of the idx2spk_feat metadata file(s).
                Mandatory; an AssertionError is raised when left None.
            use_aver_feat (bool): If True, each speaker's average embedding (read from
                the sibling 'spk2aver_<model>_spk_feat' file) is used instead of a
                randomly picked per-utterance embedding. Defaults to True.
            mixup_number (int): How many speakers' embeddings are mixed per instance.
                Must be a positive integer. Defaults to 1.
            **super_conf: Configuration forwarded to SpeechTextDataset.dataset_init_fn().
        """

        super(RandomSpkFeatDataset, self).dataset_init_fn(**super_conf)

        assert (
            spk_feat is not None
        ), f"spk_feat cannot be None. Please specify it in {self.__class__.__name__}!"
        assert (
            isinstance(mixup_number, int) and mixup_number >= 1
        ), f"mixup_number must be a positive integer, but got {mixup_number}!"
        self.mixup_number = mixup_number

        # List[str] or str -> List[str]
        if not isinstance(spk_feat, List):
            spk_feat = [spk_feat]
        # the companion metadata files (idx2spk, spk2aver_*_spk_feat) are expected
        # to live in the same directory as each spk_feat file
        metadata_dir = [os.path.dirname(s_f) for s_f in spk_feat]
        # parse the embedding model name out of the file basename,
        # e.g. 'idx2xvector_spk_feat' -> 'xvector' (assumes this naming scheme — TODO confirm)
        spk_emb_model = [
            os.path.basename(s_f).split("2")[-1].split("_")[0] for s_f in spk_feat
        ]

        # register the list of available speaker IDs
        self.idx2spk = load_idx2data_file(
            [os.path.join(m_d, "idx2spk") for m_d in metadata_dir]
        )
        self.spk_ids_list = sorted(set(self.idx2spk.values()))
        self.spk_num = len(self.spk_ids_list)
        # per-speaker usage counter, used to balance speaker selection across instances
        self.spk2freq = {spk_id: 0 for spk_id in self.spk_ids_list}

        # speaker embedding file reading, List[str] -> Dict[str, str]
        idx2spk_feat = load_idx2data_file(spk_feat)
        # group the embedding file paths by their speaker ID
        self.spk2spk_feat = {
            spk_id: {
                spk_feat_id: idx2spk_feat[spk_feat_id]
                for spk_feat_id in idx2spk_feat.keys()
                if self.idx2spk[spk_feat_id] == spk_id
            }
            for spk_id in self.spk_ids_list
        }
        if use_aver_feat:
            # speaker-average embeddings, one metadata file per embedding model
            self.spk2aver_spk_feat = load_idx2data_file(
                [
                    os.path.join(m_d, f"spk2aver_{s_e_m}_spk_feat")
                    for m_d, s_e_m in zip(metadata_dir, spk_emb_model)
                ]
            )

    def extract_main_data_fn(self, main_data: Dict[str, str]) -> Dict[str, Any] or None:
        """This hook function randomly picks up a speaker embedding feature from the
        given spk_feat file as the reference.

        The randomness is controlled by the `seed` you give in the exp_cfg.

        Returns:
            The dict produced by the parent class, extended with 'spk_feat',
            'spk_feat_ids' and 'spk_ids'; None when the parent returns an empty batch.
        """
        assert "spk_ids" not in main_data.keys(), (
            f"Please don't give spk_ids to main_data of {self.__class__.__name__}. "
            f"This Dataset is used to evaluate open-set multi-speaker TTS that uses external speaker embedding."
        )
        assert "spk_feat" not in main_data.keys(), (
            f"Please don't give spk_feat to main_data of {self.__class__.__name__}. "
            f"Your spk_feat should be given outside the main_data."
        )

        # process 'feat' and 'text' by the parent class
        main_data = super(RandomSpkFeatDataset, self).extract_main_data_fn(main_data)
        # None means empty batch received from the parent class
        if main_data is None:
            return main_data

        chosen_spk_feat_ids, chosen_spk_ids = [], []
        while len(chosen_spk_feat_ids) < self.mixup_number:
            # pick the least-frequently used speaker to balance speaker coverage;
            # the text length (when available) weights the frequency update
            random_spk_id, self.spk2freq = get_min_indices_by_freq(
                self.spk2freq,
                freq_weights=(
                    len(main_data["text"]) if "text" in main_data.keys() else None
                ),
            )
            random_spk_id = random_spk_id[0]
            # NOTE(review): the same speaker may be drawn more than once when
            # mixup_number > 1 — confirm whether that is intended

            # randomly pick up a speaker embedding feature vector
            spk_feat = self.spk2spk_feat[random_spk_id]
            spk_feat_id_list = list(spk_feat.keys())
            random_spk_feat_id = spk_feat_id_list[
                random.randint(0, len(spk_feat_id_list) - 1)
            ]
            if not hasattr(self, "spk2aver_spk_feat"):
                spk_feat = read_data_by_path(
                    spk_feat[random_spk_feat_id], return_tensor=True
                )
            else:
                # randomly pick up a useless spk_feat_id for the same randomness results
                # (random.randint above is still executed so the RNG sequence stays identical)
                random_spk_feat_id = "aver_spk_feat"
                spk_feat = read_data_by_path(
                    self.spk2aver_spk_feat[random_spk_id], return_tensor=True
                )

            # accumulate (sum) the chosen embeddings for mixup; averaged after the loop
            if "spk_feat" not in main_data.keys():
                main_data["spk_feat"] = spk_feat
            else:
                main_data["spk_feat"] += spk_feat

            chosen_spk_feat_ids.append(random_spk_feat_id)
            chosen_spk_ids.append(random_spk_id)

        # take the average of the chosen speaker embedding features
        if self.mixup_number > 1:
            main_data["spk_feat"] /= self.mixup_number
            # sort all the IDs of spk_feat and spk to make sure the naming uniqueness
            main_data["spk_feat_ids"] = "+".join(sorted(chosen_spk_feat_ids))
            main_data["spk_ids"] = "+".join(sorted(chosen_spk_ids))
        else:
            main_data["spk_feat_ids"] = chosen_spk_feat_ids[0]
            main_data["spk_ids"] = chosen_spk_ids[0]

        return main_data

extract_main_data_fn(main_data)

This hook function randomly picks up a speaker embedding feature from the given spk_feat file as the reference.

The randomness is controlled by the seed you give in the exp_cfg.

Source code in speechain/dataset/speech_text.py
def extract_main_data_fn(self, main_data: Dict[str, str]) -> Dict[str, Any] or None:
    """This hook function randomly picks up a speaker embedding feature from the
    given spk_feat file as the reference.

    The randomness is controlled by the `seed` you give in the exp_cfg.

    Returns:
        The dict produced by the parent class, extended with 'spk_feat',
        'spk_feat_ids' and 'spk_ids'; None when the parent returns an empty batch.
    """
    assert "spk_ids" not in main_data.keys(), (
        f"Please don't give spk_ids to main_data of {self.__class__.__name__}. "
        f"This Dataset is used to evaluate open-set multi-speaker TTS that uses external speaker embedding."
    )
    assert "spk_feat" not in main_data.keys(), (
        f"Please don't give spk_feat to main_data of {self.__class__.__name__}. "
        f"Your spk_feat should be given outside the main_data."
    )

    # process 'feat' and 'text' by the parent class
    main_data = super(RandomSpkFeatDataset, self).extract_main_data_fn(main_data)
    # None means empty batch received from the parent class
    if main_data is None:
        return main_data

    chosen_spk_feat_ids, chosen_spk_ids = [], []
    while len(chosen_spk_feat_ids) < self.mixup_number:
        # pick the least-frequently used speaker to balance speaker coverage;
        # the text length (when available) weights the frequency update
        random_spk_id, self.spk2freq = get_min_indices_by_freq(
            self.spk2freq,
            freq_weights=(
                len(main_data["text"]) if "text" in main_data.keys() else None
            ),
        )
        random_spk_id = random_spk_id[0]

        # randomly pick up a speaker embedding feature vector
        spk_feat = self.spk2spk_feat[random_spk_id]
        spk_feat_id_list = list(spk_feat.keys())
        random_spk_feat_id = spk_feat_id_list[
            random.randint(0, len(spk_feat_id_list) - 1)
        ]
        if not hasattr(self, "spk2aver_spk_feat"):
            spk_feat = read_data_by_path(
                spk_feat[random_spk_feat_id], return_tensor=True
            )
        else:
            # randomly pick up a useless spk_feat_id for the same randomness results
            # (random.randint above is still executed so the RNG sequence stays identical)
            random_spk_feat_id = "aver_spk_feat"
            spk_feat = read_data_by_path(
                self.spk2aver_spk_feat[random_spk_id], return_tensor=True
            )

        # accumulate (sum) the chosen embeddings for mixup; averaged after the loop
        if "spk_feat" not in main_data.keys():
            main_data["spk_feat"] = spk_feat
        else:
            main_data["spk_feat"] += spk_feat

        chosen_spk_feat_ids.append(random_spk_feat_id)
        chosen_spk_ids.append(random_spk_id)

    # take the average of the chosen speaker embedding features
    if self.mixup_number > 1:
        main_data["spk_feat"] /= self.mixup_number
        # sort all the IDs of spk_feat and spk to make sure the naming uniqueness
        main_data["spk_feat_ids"] = "+".join(sorted(chosen_spk_feat_ids))
        main_data["spk_ids"] = "+".join(sorted(chosen_spk_ids))
    else:
        main_data["spk_feat_ids"] = chosen_spk_feat_ids[0]
        main_data["spk_ids"] = chosen_spk_ids[0]

    return main_data

SpeechTextDataset

Bases: Dataset

This Dataset subclass is mainly used by ASR and TTS models.

In this subclass, each data instance is made up of an utterance and a sentence as well as the speaker information (speaker ID + speaker embedding feature).

Source code in speechain/dataset/speech_text.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
class SpeechTextDataset(Dataset):
    """This Dataset subclass is mainly used by ASR and TTS models.

    In this subclass, each data instance is made up of an utterance and a sentence as
    well as the speaker information (speaker ID + speaker embedding feature).
    """

    def dataset_init_fn(
        self,
        use_g2p: bool = False,
        unk_mask_prob: float = 0.0,
        use_speed_perturb: bool = False,
        sample_rate: int = 16000,
        perturb_range: List[float] = None,
        pitch_conf: Dict = None,
    ):
        """Dataset initialization function.

        Args:
            use_g2p (bool, optional): Whether to process the raw string by G2P. We don't
                recommend you to turn it on because on-the-fly transformation from string
                to phoneme list consumes a lot of CPU resources. Defaults to False.
            unk_mask_prob (float, optional): Probability of masking tokens as unknown.
                Must be in [0, 1]. Defaults to 0.0.
            use_speed_perturb (bool, optional): Whether to perturb the speed of the
                waveforms. Defaults to False.
            sample_rate (int, optional): Audio sampling rate in Hz. Defaults to 16000.
            perturb_range (List[float], optional): Range of speed perturbation factors.
                Defaults to [0.9, 1.0, 1.1] when left None.
            pitch_conf (Dict, optional): The configuration given to convert_wav_to_pitch()
                for pitch extraction. If not given, pitch extraction will not be done
                on-the-fly. Defaults to None.

        Note:
            Phoneme related: use_g2p
            Waveform related: unk_mask_prob, use_speed_perturb, sample_rate, perturb_range
            Pitch related: pitch_conf
        """
        # None is used as the declared default to avoid the mutable-default-argument
        # pitfall: a literal list default would be one shared object across all calls
        if perturb_range is None:
            perturb_range = [0.9, 1.0, 1.1]

        # register sampling rate for later check
        self.sample_rate = sample_rate
        warnings.warn(
            f"The waveform sampling rate of {self.__class__.__name__} is set to {sample_rate}. "
            f"All the extracted waveforms will be downsampled into {sample_rate} if needed. "
            f"Please make sure that {sample_rate} is the same with your model! "
            f"If this is not your target sampling rate, "
            f"please change it by the key 'sample_rate' in the item 'dataset_conf' under 'data_cfg'. "
            f"If you want to train Language Models or synthesize speech by text, you can ignore this warning."
        )

        assert (
            0 <= unk_mask_prob <= 1
        ), f"unk_mask_prob should be a float number in [0, 1], but got {unk_mask_prob}!"
        self.unk_mask_prob = unk_mask_prob

        # phoneme extraction: the G2p converter is only instantiated when requested
        if use_g2p:
            self.g2p = G2p()

        # speed perturbation: build one resampler per perturbation factor up front
        if use_speed_perturb:
            self.perturb_range = perturb_range
            self.speed_resampler_list = [
                torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=int(sample_rate * factor)
                )
                for factor in perturb_range
            ]

        # pitch extraction
        if pitch_conf is not None:
            # 'sr' (if given) must agree with sample_rate; it is then pinned explicitly.
            # NOTE(review): this writes into the caller-supplied pitch_conf dict — confirm
            # that this in-place mutation is acceptable to callers.
            if "sr" in pitch_conf.keys():
                assert pitch_conf["sr"] == self.sample_rate, (
                    f"The sampling rate in your given 'pitch_conf' ({pitch_conf['sr']}) is different from your "
                    f"given sample_rate ({self.sample_rate})!"
                )
            pitch_conf["sr"] = self.sample_rate
            self.pitch_extract_fn = partial(
                convert_wav_to_pitch, return_tensor=True, **pitch_conf
            )

    @staticmethod
    def data_len_register_fn(
        main_data: Dict[str, Dict[str, str]]
    ) -> Dict[str, int or float] or None:
        """Register a length for each data instance.

        Returns:
            A mapping from instance index to the character count of its sentence
            when 'text' is present in main_data; None otherwise.
        """
        if "text" not in main_data:
            return None
        return {idx: len(sentence) for idx, sentence in main_data["text"].items()}

    def collate_main_data_fn(
        self, batch_dict: Dict[str, List]
    ) -> Dict[str, torch.Tensor or List]:
        """The utterances used for training ASR and TTS models may have different
        lengths, so we need to do the padding operations to make them equal in length.

        The loaded speech feature vectors will be arranged into a single matrix with 0 padding at the end of short
        vectors. Text data remains unprocessed strings and the tokenization will be done later in the model.

        Args:
            batch_dict (Dict[str, List]): The keys of the input `batch_dict` dictionary should be one of the following:
                1. `feat`: a List of 2d `torch.Tensor` with different lengths.
                2. `pitch`: a List of 1d `torch.Tensor` with different lengths.
                3. `text`: a List of text strings.
                4. `duration`: a List of phoneme duration sequences with different lengths.
                5. `spk_ids`: a List of speaker ID strings.
                6. `spk_feat`: a List of 2d `torch.Tensor` with equal lengths.

        Returns:
            A dictionary mapping strings to either torch.Tensor or List, where:
                - feat and spk_feat are three-dimensional torch.Tensor
                - text and spk_ids are lists of raw strings whose discretization is done in the Model object

        Raises:
            TypeError: If a 'feat', 'pitch' or 'duration' element is neither
                np.ndarray nor torch.Tensor (nor List, for 'duration').
        """

        # --- 1. Pad Speech Data and Stack them together --- #
        if "feat" in batch_dict.keys():
            # para init
            feat_len = torch.LongTensor([ele.shape[0] for ele in batch_dict["feat"]])
            batch_size, feat_maxlen, feat_dim = (
                len(batch_dict["feat"]),
                feat_len.max().item(),
                batch_dict["feat"][0].shape[-1],
            )

            # acoustic feature padding, feat.dtype needs to match the type of model parameters (torch.float32)
            feat = torch.zeros((batch_size, feat_maxlen, feat_dim), dtype=torch.float32)
            # overwrite the padding matrix with each feat vector
            for i in range(batch_size):
                # process feat data based on data type
                if isinstance(batch_dict["feat"][i], np.ndarray):
                    feat[i][: feat_len[i]] = torch.tensor(batch_dict["feat"][i])
                elif isinstance(batch_dict["feat"][i], torch.Tensor):
                    feat[i][: feat_len[i]] = batch_dict["feat"][i]
                # only support np.ndarray and torch.Tensor now
                else:
                    raise TypeError(
                        f"{self.__class__.__name__} only supports np.ndarray and torch.Tensor now!"
                    )

            # update 'feat' and attach 'feat_len' for later model forward
            batch_dict["feat"] = feat
            batch_dict["feat_len"] = feat_len

        # --- 2. Pad Pitch Data and Stack them together --- #
        if "pitch" in batch_dict.keys():
            # para init
            pitch_len = torch.LongTensor([ele.shape[0] for ele in batch_dict["pitch"]])
            batch_size, pitch_maxlen = len(batch_dict["pitch"]), pitch_len.max().item()

            # pitch padding, pitch.dtype needs to match the type of model parameters (torch.float32)
            pitch = torch.zeros((batch_size, pitch_maxlen), dtype=torch.float32)
            # overwrite the padding matrix with each pitch vector
            for i in range(batch_size):
                # process pitch data based on data type
                if isinstance(batch_dict["pitch"][i], np.ndarray):
                    pitch[i][: pitch_len[i]] = torch.tensor(batch_dict["pitch"][i])
                elif isinstance(batch_dict["pitch"][i], torch.Tensor):
                    pitch[i][: pitch_len[i]] = batch_dict["pitch"][i]
                # only support np.ndarray and torch.Tensor now
                else:
                    raise TypeError(
                        f"{self.__class__.__name__} only supports np.ndarray and torch.Tensor now!"
                    )

            batch_dict["pitch"] = pitch
            batch_dict["pitch_len"] = pitch_len

        # --- 3. Separate Phoneme Duration Data into Text Data and Duration Data --- #
        if "duration" in batch_dict.keys():
            # para init
            batch_size, duration_len = len(batch_dict["duration"]), torch.LongTensor(
                [len(ele) for ele in batch_dict["duration"]]
            )

            # duration padding, feat.dtype needs to match the type of model parameters (torch.float32)
            duration = torch.zeros(
                (batch_size, duration_len.max().item()), dtype=torch.float32
            )
            # overwrite the padding matrix with each duration vector
            for i in range(batch_size):
                # process duration data based on data type
                if isinstance(batch_dict["duration"][i], (np.ndarray, List)):
                    duration[i][: duration_len[i]] = torch.tensor(
                        batch_dict["duration"][i]
                    )
                elif isinstance(batch_dict["duration"][i], torch.Tensor):
                    duration[i][: duration_len[i]] = batch_dict["duration"][i]
                else:
                    # fixed: `self.__class__.name` (no such attribute, raised
                    # AttributeError instead of the intended message) -> `__name__`
                    raise TypeError(
                        f"{self.__class__.__name__} only supports np.ndarray and torch.Tensor now!"
                    )

            # attach 'duration' and 'duration_len' for model forward
            batch_dict["duration"] = duration
            batch_dict["duration_len"] = duration_len

        # --- 4. Stack Speaker Embedding Feature together --- #
        if "spk_feat" in batch_dict.keys():
            # spk_feat elements share one length, so a plain stack is sufficient
            batch_dict["spk_feat"] = torch.stack(batch_dict["spk_feat"])

        return batch_dict

    def extract_main_data_fn(self, main_data: Dict) -> Dict[str, Any] or None:
        """The function that loads speech-text data from the disk. If the speech is in
        the form of raw waveforms, the last dimension should be expanded to 1 of raw
        speech for compatibility with acoustic feature.

        Args:
            main_data: Dict[str, str]
                The keys of the input main_data dictionary should be one of the following:
                    1. 'feat': speech features, can be either raw waveforms or acoustic features like log-mel or MFCC.
                    2. 'text': transcript text, in the form of raw string. The tokenization will be done in the ASR and
                    TTS models.
                    3. 'duration': phoneme durations. used for training fastspeech2 model.
                    4. 'spk_ids': speaker ID, in the form of raw string. The speaker discretization will be done in the
                    ASR and TTS models.
                    5. 'spk_feat': speaker embedding features.
                `spk_ids` and `spk_feat` are designed for multi-speaker TTS model and are not mandatory to be included
                in `main_data; 'feat' and 'text' are mandatory to be included for ASR and TTS training.
                However, during model testing, we can choose to only include one of 'feat' and 'text' here to reduce the
                CPU burden.

        Returns:
            `feat` and `spk_feat` are in the form of two-dimensional `torch.Tensor`;
            `text` and `spk_ids` are in the form of raw strings whose discretization is done in the Model object.
        """
        assert (
            "feat" in main_data.keys() or "text" in main_data.keys()
        ), "Please at least include one of 'feat' and 'text' in a single batch."
        for key in main_data.keys():
            if key not in ["feat", "text", "duration", "spk_ids", "spk_feat"]:
                raise RuntimeError(
                    f"Unknown data name {key}! "
                    f"For {self.__class__.__name__}, the key in 'main_data' must be one of "
                    "'feat' (for paths of raw waveforms or acoustic features), "
                    "'text' (for transcript text data), "
                    "'duration' (for phoneme duration data), "
                    "'spk_ids' (for speaker IDs), "
                    "'spk_feat' (for speaker embedding features)."
                )

        # --- 1. Speech Data Extraction --- #
        if "feat" in main_data.keys():
            # read the selected data speech feature as a tensor by its path
            main_data["feat"], sample_rate = read_data_by_path(
                main_data["feat"], return_sample_rate=True, return_tensor=True
            )
            # sometimes the extracted waveform data from an audio file can be empty, skip the current file if that happens
            if main_data["feat"].size(0) == 0:
                return None

            # on-the-fly downsampling if extracted sampling rate is larger than the built-in one
            if sample_rate > self.sample_rate:
                if not hasattr(self, "wav_resampler_dict"):
                    self.wav_resampler_dict = {
                        sample_rate: torchaudio.transforms.Resample(
                            orig_freq=sample_rate, new_freq=self.sample_rate
                        )
                    }
                main_data["feat"] = self.wav_resampler_dict[sample_rate](
                    main_data["feat"].squeeze(-1)
                ).unsqueeze(-1)
            # extracted waveforms could not have lower sampling rate than the built-in one
            elif sample_rate < self.sample_rate:
                raise RuntimeError(
                    f"The current waveform has the lower sampling rate than {self.sample_rate}!"
                )

            # perturb the speed of the extracted speech if specified
            if hasattr(self, "speed_resampler_list"):
                assert sample_rate == self.sample_rate, (
                    f"Your given sample rate ({self.sample_rate}) is different from the real one gotten from the "
                    f"waveform ({sample_rate})!"
                )
                resampler_index = torch.randint(len(self.speed_resampler_list), (1,))[0]
                main_data["feat"] = self.speed_resampler_list[resampler_index](
                    main_data["feat"].squeeze(-1)
                ).unsqueeze(-1)

            # extract the pitch from the speech on-the-fly
            if hasattr(self, "pitch_extract_fn"):
                try:
                    main_data["pitch"] = self.pitch_extract_fn(main_data["feat"])
                # IndexError means all the pitch values are unvoiced (=0.0)
                # return None to remove this utterance from the current batch
                except IndexError:
                    return None

        # --- 2. Transcript Text Extraction --- #
        if "text" in main_data.keys():
            # text length is not returned because the text here is just a raw string
            assert isinstance(
                main_data["text"], str
            ), f"The 'text' data should be given as a string, but got {main_data['text']}"
            # for the text data in the format of a list
            if main_data["text"].startswith("[") and main_data["text"].endswith("]"):
                main_data["text"] = main_data["text"][1:-1]
                # split the text into individual tokens by a comma followed a blank
                main_data["text"] = main_data["text"].split(", ")
                # remove the single quote marks surrounding each token if needed
                main_data["text"] = [
                    (
                        token[1:-1]
                        if token.startswith("'") and token.endswith("'")
                        else token
                    )
                    for token in main_data["text"]
                ]
            # process the raw string by G2P if specified
            elif hasattr(self, "g2p"):
                phn_list = self.g2p(main_data["text"])
                main_data["text"] = [
                    phn if phn != " " else "<space>"
                    for phn in phn_list
                    if phn not in abnormal_phns
                ]

        # --- 3. Phoneme Duration Extraction --- #
        if "duration" in main_data.keys():
            # text length is not returned because the text here is just a raw string
            assert isinstance(
                main_data["duration"], str
            ), f"The 'duration' data should be given as a string, but got {main_data['duration']}"
            # for the text data in the format of a list
            if main_data["duration"].startswith("[") and main_data["duration"].endswith(
                "]"
            ):
                main_data["duration"] = main_data["duration"][1:-1]
                # split the text into individual tokens by a comma followed a blank
                main_data["duration"] = main_data["duration"].split(", ")
                # remove the single quote marks surrounding each token if needed
                main_data["duration"] = [
                    (
                        float(duration[1:-1])
                        if duration.startswith("'") and duration.endswith("'")
                        else float(duration)
                    )
                    for duration in main_data["duration"]
                ]
            else:
                raise RuntimeError(
                    "The 'duration' string should be surrounded by a pair of square brackets!"
                )

        # --- 4. Silence Trimming at the two ends --- #
        # trim the silence at two ends of the waveforms if the phoneme sequence starts or ends with spaces
        if ("text" in main_data.keys() and isinstance(main_data["text"], List)) and (
            main_data["text"][0] == "<space>" or main_data["text"][-1] == "<space>"
        ):
            # trim both feat and text
            if "feat" in main_data.keys():
                assert "duration" in main_data.keys(), (
                    "If you want to trim the silence at two ends of speech, "
                    "please give 'duration' in 'main_data' of the item 'dataset_conf' under 'data_cfg'."
                )
                front_trim_len, tail_trim_len, total_duration = (
                    0,
                    0,
                    sum(main_data["duration"]),
                )
                try:
                    # sum up all the silence tokens at the beginning
                    while main_data["text"][0] == "<space>":
                        front_trim_len += main_data["duration"][0]
                        main_data["text"], main_data["duration"] = (
                            main_data["text"][1:],
                            main_data["duration"][1:],
                        )
                    # sum up all the silence tokens at the end
                    while main_data["text"][-1] == "<space>":
                        tail_trim_len += main_data["duration"][-1]
                        main_data["text"], main_data["duration"] = (
                            main_data["text"][:-1],
                            main_data["duration"][:-1],
                        )
                # IndexError means the text is full of '<space>'
                # return None to remove this utterance from the current batch
                except IndexError:
                    return None

                # normalize the trimming lengths by the total duration length
                front_trim_len, tail_trim_len = (
                    front_trim_len / total_duration,
                    tail_trim_len / total_duration,
                )
                # trim the extra silence in feat (waveforms or acoustic features)
                feat_start, feat_end = int(
                    front_trim_len * len(main_data["feat"])
                ), int(tail_trim_len * len(main_data["feat"]))
                main_data["feat"] = main_data["feat"][feat_start:]
                if feat_end > 0:
                    main_data["feat"] = main_data["feat"][:-feat_end]

                # also trim the two ends of pitch values if extracted
                if "pitch" in main_data.keys():
                    pitch_start, pitch_end = int(
                        front_trim_len * len(main_data["pitch"])
                    ), int(tail_trim_len * len(main_data["pitch"]))
                    main_data["pitch"] = main_data["pitch"][pitch_start:]
                    if pitch_end > 0:
                        main_data["pitch"] = main_data["pitch"][:-pitch_end]

            # only trim text if feat is not given
            else:
                try:
                    # sum up all the <space> tokens at the beginning
                    while main_data["text"][0] == "<space>":
                        main_data["text"] = main_data["text"][1:]
                        if "duration" in main_data.keys():
                            main_data["duration"] = main_data["duration"][1:]
                    # sum up all the <space> tokens at the end
                    while main_data["text"][-1] == "<space>":
                        main_data["text"] = main_data["text"][:-1]
                        if "duration" in main_data.keys():
                            main_data["duration"] = main_data["duration"][:-1]
                # IndexError means the text is full of '<space>'
                # return None to remove this utterance from the current batch
                except IndexError:
                    return None

        # --- 5. Randomly Masking the text data by unknown tokens (After silence trimming for data safety) --- #
        if self.unk_mask_prob > 0:
            assert "text" in main_data.keys() and isinstance(
                main_data["text"], List
            ), "If you want to activate unk_mask_prob, text must be given in the 'main_date' tag as a token sequence."

            # Get the start and end indices of words based on the positions of space tokens
            space_indices = [
                i for i, token in enumerate(main_data["text"]) if token == "<space>"
            ]
            word_start_indices, word_end_indices = [0] + [
                s_i + 1 for s_i in space_indices
            ], space_indices + [len(main_data["text"])]

            # Determine which words to mask
            word_mask_flags = (
                np.random.rand(len(word_start_indices)) < self.unk_mask_prob
            )

            _tmp_text, _tmp_duration = [], []
            for i in range(len(word_mask_flags)):
                # If the word should be masked, add an '<unk>' token
                if word_mask_flags[i]:
                    _tmp_text.append("<unk>")
                    if "duration" in main_data.keys():
                        _sum_duration = sum(
                            main_data["duration"][
                                word_start_indices[i] : word_end_indices[i]
                            ]
                        )
                        _tmp_duration.append(round(_sum_duration, 2))

                # If the word shouldn't be masked, add the original tokens of the word
                else:
                    _tmp_text += main_data["text"][
                        word_start_indices[i] : word_end_indices[i]
                    ]
                    if "duration" in main_data.keys():
                        _tmp_duration += main_data["duration"][
                            word_start_indices[i] : word_end_indices[i]
                        ]

                # Add space tokens and their durations between words, except for the last word
                if i != len(word_mask_flags) - 1:
                    _tmp_text.append(main_data["text"][word_end_indices[i]])
                    if "duration" in main_data.keys():
                        _tmp_duration.append(main_data["duration"][word_end_indices[i]])

            # Update main_data with the new text and duration information
            main_data["text"] = _tmp_text
            if "duration" in main_data.keys():
                main_data["duration"] = _tmp_duration

        # --- 6. Speaker ID Extraction --- #
        if "spk_ids" in main_data.keys():
            # the speaker ID here is just a raw string
            assert isinstance(
                main_data["spk_ids"], str
            ), f"The 'spk_ids' data should be given as a string, but got {main_data['spk_ids']}"

        # --- 7. Speaker Embedding Feature --- #
        if "spk_feat" in main_data.keys():
            # read the selected data speech feature as a tensor by its path
            main_data["spk_feat"] = read_data_by_path(
                main_data["spk_feat"], return_tensor=True
            )

        return main_data

    def __repr__(self):
        """Return a concise summary of the active configuration options.

        Always shows the sampling rate; optional features (G2P, speed
        perturbation, pitch extraction, unknown-token masking) are listed
        only when they were enabled at initialization time.
        """
        parts = [f"sample_rate={self.sample_rate}"]
        if hasattr(self, "g2p"):
            parts.append("use_g2p=True")
        if hasattr(self, "speed_resampler_list"):
            parts.append(f"speed_perturb_range={self.perturb_range}")
        if hasattr(self, "pitch_extract_fn"):
            parts.append("pitch_extract=True")
        if self.unk_mask_prob > 0:
            parts.append(f"unk_mask_prob={self.unk_mask_prob}")
        return f"{self.__class__.__name__}({', '.join(parts)})"

collate_main_data_fn(batch_dict)

The utterances used for training ASR and TTS models may have different lengths, so we need to do the padding operations to make them equal in length.

The loaded speech feature vectors will be arranged into a single matrix with 0 padding at the end of short vectors. Text data remains unprocessed strings and the tokenization will be done later in the model.

Parameters:

Name Type Description Default
batch_dict Dict[str, List]

The keys of the input batch_dict dictionary should be one of the following: 1. feat: a List of 2d torch.Tensor with different lengths. 2. pitch: a List of 1d torch.Tensor with different lengths. 3. text: a List of text strings. 4. spk_ids: a List of speaker ID strings. 5. spk_feat: a List of 2d torch.Tensor with equal lengths.

required

Returns:

Type Description
Dict[str, Tensor or List]

A dictionary mapping strings to either torch.Tensor or List, where: - feat and spk_feat are three-dimensional torch.Tensor - text and spk_ids are lists of raw strings whose discretization is done in the Model object

Source code in speechain/dataset/speech_text.py
def collate_main_data_fn(
    self, batch_dict: Dict[str, List]
) -> Dict[str, torch.Tensor or List]:
    """The utterances used for training ASR and TTS models may have different
    lengths, so we need to do the padding operations to make them equal in length.

    The loaded speech feature vectors will be arranged into a single matrix with 0 padding at the end of short
    vectors. Text data remains unprocessed strings and the tokenization will be done later in the model.

    Args:
        batch_dict (Dict[str, List]): The keys of the input `batch_dict` dictionary should be one of the following:
            1. `feat`: a List of 2d `torch.Tensor` with different lengths.
            2. `pitch`: a List of 1d `torch.Tensor` with different lengths.
            3. `text`: a List of text strings.
            4. `spk_ids`: a List of speaker ID strings.
            5. `spk_feat`: a List of 2d `torch.Tensor` with equal lengths.

    Returns:
        A dictionary mapping strings to either torch.Tensor or List, where:
            - feat and spk_feat are three-dimensional torch.Tensor
            - text and spk_ids are lists of raw strings whose discretization is done in the Model object

    Raises:
        TypeError: If an element of 'feat', 'pitch', or 'duration' is neither
            np.ndarray nor torch.Tensor (nor List, for 'duration').
    """

    # --- 1. Pad Speech Data and Stack them together --- #
    if "feat" in batch_dict.keys():
        # para init
        feat_len = torch.LongTensor([ele.shape[0] for ele in batch_dict["feat"]])
        batch_size, feat_maxlen, feat_dim = (
            len(batch_dict["feat"]),
            feat_len.max().item(),
            batch_dict["feat"][0].shape[-1],
        )

        # acoustic feature padding, feat.dtype needs to match the type of model parameters (torch.float32)
        feat = torch.zeros((batch_size, feat_maxlen, feat_dim), dtype=torch.float32)
        # overwrite the padding matrix with each feat vector
        for i in range(batch_size):
            # process feat data based on data type
            if isinstance(batch_dict["feat"][i], np.ndarray):
                feat[i][: feat_len[i]] = torch.tensor(batch_dict["feat"][i])
            elif isinstance(batch_dict["feat"][i], torch.Tensor):
                feat[i][: feat_len[i]] = batch_dict["feat"][i]
            # only support np.ndarray and torch.Tensor now
            else:
                raise TypeError(
                    f"{self.__class__.__name__} only supports np.ndarray and torch.Tensor now!"
                )

        # update 'feat' and attach 'feat_len' for later model forward
        batch_dict["feat"] = feat
        batch_dict["feat_len"] = feat_len

    # --- 2. Pad Pitch Data and Stack them together --- #
    if "pitch" in batch_dict.keys():
        # para init
        pitch_len = torch.LongTensor([ele.shape[0] for ele in batch_dict["pitch"]])
        batch_size, pitch_maxlen = len(batch_dict["pitch"]), pitch_len.max().item()

        # pitch padding, pitch.dtype needs to match the type of model parameters (torch.float32)
        pitch = torch.zeros((batch_size, pitch_maxlen), dtype=torch.float32)
        # overwrite the padding matrix with each pitch vector
        for i in range(batch_size):
            # process pitch data based on data type
            if isinstance(batch_dict["pitch"][i], np.ndarray):
                pitch[i][: pitch_len[i]] = torch.tensor(batch_dict["pitch"][i])
            elif isinstance(batch_dict["pitch"][i], torch.Tensor):
                pitch[i][: pitch_len[i]] = batch_dict["pitch"][i]
            # only support np.ndarray and torch.Tensor now
            else:
                raise TypeError(
                    f"{self.__class__.__name__} only supports np.ndarray and torch.Tensor now!"
                )

        batch_dict["pitch"] = pitch
        batch_dict["pitch_len"] = pitch_len

    # --- 3. Separate Phoneme Duration Data into Text Data and Duration Data --- #
    if "duration" in batch_dict.keys():
        # para init
        batch_size, duration_len = len(batch_dict["duration"]), torch.LongTensor(
            [len(ele) for ele in batch_dict["duration"]]
        )

        # duration padding, duration.dtype needs to match the type of model parameters (torch.float32)
        duration = torch.zeros(
            (batch_size, duration_len.max().item()), dtype=torch.float32
        )
        # overwrite the padding matrix with each duration vector
        for i in range(batch_size):
            # process duration data based on data type
            if isinstance(batch_dict["duration"][i], (np.ndarray, List)):
                duration[i][: duration_len[i]] = torch.tensor(
                    batch_dict["duration"][i]
                )
            elif isinstance(batch_dict["duration"][i], torch.Tensor):
                duration[i][: duration_len[i]] = batch_dict["duration"][i]
            else:
                # fix: was `self.__class__.name`, which raises AttributeError
                # instead of the intended TypeError message
                raise TypeError(
                    f"{self.__class__.__name__} only supports np.ndarray and torch.Tensor now!"
                )

        # attach 'duration' and 'duration_len' for model forward
        batch_dict["duration"] = duration
        batch_dict["duration_len"] = duration_len

    # --- 4. Stack Speaker Embedding Feature together --- #
    # spk_feat vectors are assumed to be of equal length, so a plain stack suffices
    if "spk_feat" in batch_dict.keys():
        batch_dict["spk_feat"] = torch.stack(batch_dict["spk_feat"])

    return batch_dict

data_len_register_fn(main_data) staticmethod

Returns:

Type Description
Dict[str, int or float] or None

If 'text' is given in main_data, return the number of characters in each sentence. Otherwise, return None.

Source code in speechain/dataset/speech_text.py
@staticmethod
def data_len_register_fn(
    main_data: Dict[str, Dict[str, str]]
) -> Dict[str, int or float] or None:
    """

    Returns:
        If 'text' is given in main_data, return the number of characters in each sentence.
        Otherwise, return None

    """
    if "text" in main_data.keys():
        return {key: len(value) for key, value in main_data["text"].items()}
    else:
        return None

dataset_init_fn(use_g2p=False, unk_mask_prob=0.0, use_speed_perturb=False, sample_rate=16000, perturb_range=[0.9, 1.0, 1.1], pitch_conf=None)

Dataset initialization function.

Parameters:

Name Type Description Default
use_g2p bool

Whether to process the raw string by G2P. We don't recommend turning it on because the on-the-fly transformation from string to phoneme list consumes a lot of CPU resources. Defaults to False.

False
unk_mask_prob float

Probability of masking tokens as unknown. Defaults to 0.0.

0.0
use_speed_perturb bool

Whether to perturb the speed of the waveforms. Defaults to False.

False
sample_rate int

Audio sampling rate in Hz. Defaults to 16000.

16000
perturb_range List[float]

Range of speed perturbation factors. Defaults to [0.9, 1.0, 1.1].

[0.9, 1.0, 1.1]
pitch_conf Dict

The configuration given to convert_wav_to_pitch() for pitch extraction. If not given, pitch extraction will not be done on-the-fly. Defaults to None.

None
Note

Phoneme related: use_g2p Waveform related: unk_mask_prob, use_speed_perturb, sample_rate, perturb_range Pitch related: pitch_conf

Source code in speechain/dataset/speech_text.py
def dataset_init_fn(
    self,
    use_g2p: bool = False,
    unk_mask_prob: float = 0.0,
    use_speed_perturb: bool = False,
    sample_rate: int = 16000,
    perturb_range: List[float] = None,
    pitch_conf: Dict = None,
):
    """Dataset initialization function.

    Args:
        use_g2p (bool, optional): Whether to process the raw string by G2P. We don't
            recommend turning it on because the on-the-fly transformation from string
            to phoneme list consumes a lot of CPU resources. Defaults to False.
        unk_mask_prob (float, optional): Probability of masking words as unknown
            tokens in the text data. Defaults to 0.0.
        use_speed_perturb (bool, optional): Whether to perturb the speed of the
            waveforms. Defaults to False.
        sample_rate (int, optional): Audio sampling rate in Hz. Defaults to 16000.
        perturb_range (List[float], optional): Range of speed perturbation factors.
            Defaults to [0.9, 1.0, 1.1]. (The default is bound per call instead of
            in the signature to avoid the mutable-default-argument pitfall.)
        pitch_conf (Dict, optional): The configuration given to convert_wav_to_pitch()
            for pitch extraction. If not given, pitch extraction will not be done
            on-the-fly. Defaults to None.

    Note:
        Phoneme related: use_g2p
        Text related: unk_mask_prob
        Waveform related: use_speed_perturb, sample_rate, perturb_range
        Pitch related: pitch_conf
    """
    # fix: bind the default perturbation range per call instead of sharing one
    # mutable list object across all instances (mutable default argument pitfall)
    if perturb_range is None:
        perturb_range = [0.9, 1.0, 1.1]

    # register sampling rate for later check
    self.sample_rate = sample_rate
    warnings.warn(
        f"The waveform sampling rate of {self.__class__.__name__} is set to {sample_rate}. "
        f"All the extracted waveforms will be downsampled into {sample_rate} if needed. "
        f"Please make sure that {sample_rate} is the same with your model! "
        f"If this is not your target sampling rate, "
        f"please change it by the key 'sample_rate' in the item 'dataset_conf' under 'data_cfg'. "
        f"If you want to train Language Models or synthesize speech by text, you can ignore this warning."
    )

    assert (
        0 <= unk_mask_prob <= 1
    ), f"unk_mask_prob should be a float number in [0, 1], but got {unk_mask_prob}!"
    self.unk_mask_prob = unk_mask_prob

    # phoneme extraction
    if use_g2p:
        self.g2p = G2p()

    # speed perturbation: one resampler per perturbation factor
    if use_speed_perturb:
        self.perturb_range = perturb_range
        self.speed_resampler_list = [
            torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=int(sample_rate * factor)
            )
            for factor in perturb_range
        ]

    # pitch extraction
    if pitch_conf is not None:
        # the sampling rate in pitch_conf (if any) must agree with the dataset one
        if "sr" in pitch_conf.keys():
            assert pitch_conf["sr"] == self.sample_rate, (
                f"The sampling rate in your given 'pitch_conf' ({pitch_conf['sr']}) is different from your "
                f"given sample_rate ({self.sample_rate})!"
            )
        pitch_conf["sr"] = self.sample_rate
        self.pitch_extract_fn = partial(
            convert_wav_to_pitch, return_tensor=True, **pitch_conf
        )

extract_main_data_fn(main_data)

The function that loads speech-text data from the disk. If the speech is in the form of raw waveforms, the last dimension should be expanded to 1 of raw speech for compatibility with acoustic feature.

Parameters:

Name Type Description Default
main_data Dict

Dict[str, str] The keys of the input main_data dictionary should be one of the following: 1. 'feat': speech features, can be either raw waveforms or acoustic features like log-mel or MFCC. 2. 'text': transcript text, in the form of raw string. The tokenization will be done in the ASR and TTS models. 3. 'duration': phoneme durations. used for training fastspeech2 model. 4. 'spk_ids': speaker ID, in the form of raw string. The speaker discretization will be done in the ASR and TTS models. 5. 'spk_feat': speaker embedding features. spk_ids and spk_feat are designed for multi-speaker TTS model and are not mandatory to be included in `main_data`; 'feat' and 'text' are mandatory to be included for ASR and TTS training. However, during model testing, we can choose to only include one of 'feat' and 'text' here to reduce the CPU burden.

required

Returns:

Type Description
Dict[str, Any] or None

feat and spk_feat are in the form of two-dimensional torch.Tensor;

Dict[str, Any] or None

text and spk_ids are in the form of raw strings whose discretization is done in the Model object.

Source code in speechain/dataset/speech_text.py
def extract_main_data_fn(self, main_data: Dict) -> Dict[str, Any] or None:
    """The function that loads speech-text data from the disk. If the speech is in
    the form of raw waveforms, the last dimension should be expanded to 1 of raw
    speech for compatibility with acoustic feature.

    Args:
        main_data: Dict[str, str]
            The keys of the input main_data dictionary should be one of the following:
                1. 'feat': speech features, can be either raw waveforms or acoustic features like log-mel or MFCC.
                2. 'text': transcript text, in the form of raw string. The tokenization will be done in the ASR and
                TTS models.
                3. 'duration': phoneme durations. used for training fastspeech2 model.
                4. 'spk_ids': speaker ID, in the form of raw string. The speaker discretization will be done in the
                ASR and TTS models.
                5. 'spk_feat': speaker embedding features.
            `spk_ids` and `spk_feat` are designed for multi-speaker TTS model and are not mandatory to be included
            in `main_data`; 'feat' and 'text' are mandatory to be included for ASR and TTS training.
            However, during model testing, we can choose to only include one of 'feat' and 'text' here to reduce the
            CPU burden.

    Returns:
        `feat` and `spk_feat` are in the form of two-dimensional `torch.Tensor`;
        `text` and `spk_ids` are in the form of raw strings whose discretization is done in the Model object.
        None is returned to drop the utterance from the current batch (empty waveform,
        all-unvoiced pitch, or text made entirely of '<space>' tokens).
    """
    assert (
        "feat" in main_data.keys() or "text" in main_data.keys()
    ), "Please at least include one of 'feat' and 'text' in a single batch."
    for key in main_data.keys():
        if key not in ["feat", "text", "duration", "spk_ids", "spk_feat"]:
            raise RuntimeError(
                f"Unknown data name {key}! "
                f"For {self.__class__.__name__}, the key in 'main_data' must be one of "
                "'feat' (for paths of raw waveforms or acoustic features), "
                "'text' (for transcript text data), "
                "'duration' (for phoneme duration data), "
                "'spk_ids' (for speaker IDs), "
                "'spk_feat' (for speaker embedding features)."
            )

    # --- 1. Speech Data Extraction --- #
    if "feat" in main_data.keys():
        # read the selected data speech feature as a tensor by its path
        main_data["feat"], sample_rate = read_data_by_path(
            main_data["feat"], return_sample_rate=True, return_tensor=True
        )
        # sometimes the extracted waveform data from an audio file can be empty, skip the current file if that happens
        if main_data["feat"].size(0) == 0:
            return None

        # on-the-fly downsampling if extracted sampling rate is larger than the built-in one
        if sample_rate > self.sample_rate:
            if not hasattr(self, "wav_resampler_dict"):
                self.wav_resampler_dict = {}
            # fix: register a resampler for every unseen source sampling rate;
            # previously the dict was only built once for the first rate met,
            # so a second, different rate raised a KeyError here
            if sample_rate not in self.wav_resampler_dict:
                self.wav_resampler_dict[sample_rate] = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=self.sample_rate
                )
            main_data["feat"] = self.wav_resampler_dict[sample_rate](
                main_data["feat"].squeeze(-1)
            ).unsqueeze(-1)
            # fix: the waveform is now at the built-in rate, so update the local
            # variable; otherwise the speed-perturbation assertion below would
            # spuriously fail for downsampled audio
            sample_rate = self.sample_rate
        # extracted waveforms could not have lower sampling rate than the built-in one
        elif sample_rate < self.sample_rate:
            raise RuntimeError(
                f"The current waveform has the lower sampling rate than {self.sample_rate}!"
            )

        # perturb the speed of the extracted speech if specified
        if hasattr(self, "speed_resampler_list"):
            assert sample_rate == self.sample_rate, (
                f"Your given sample rate ({self.sample_rate}) is different from the real one gotten from the "
                f"waveform ({sample_rate})!"
            )
            # pick one of the pre-built speed resamplers at random
            resampler_index = torch.randint(len(self.speed_resampler_list), (1,))[0]
            main_data["feat"] = self.speed_resampler_list[resampler_index](
                main_data["feat"].squeeze(-1)
            ).unsqueeze(-1)

        # extract the pitch from the speech on-the-fly
        if hasattr(self, "pitch_extract_fn"):
            try:
                main_data["pitch"] = self.pitch_extract_fn(main_data["feat"])
            # IndexError means all the pitch values are unvoiced (=0.0)
            # return None to remove this utterance from the current batch
            except IndexError:
                return None

    # --- 2. Transcript Text Extraction --- #
    if "text" in main_data.keys():
        # text length is not returned because the text here is just a raw string
        assert isinstance(
            main_data["text"], str
        ), f"The 'text' data should be given as a string, but got {main_data['text']}"
        # for the text data in the format of a list
        if main_data["text"].startswith("[") and main_data["text"].endswith("]"):
            main_data["text"] = main_data["text"][1:-1]
            # split the text into individual tokens by a comma followed a blank
            main_data["text"] = main_data["text"].split(", ")
            # remove the single quote marks surrounding each token if needed
            main_data["text"] = [
                (
                    token[1:-1]
                    if token.startswith("'") and token.endswith("'")
                    else token
                )
                for token in main_data["text"]
            ]
        # process the raw string by G2P if specified
        elif hasattr(self, "g2p"):
            phn_list = self.g2p(main_data["text"])
            main_data["text"] = [
                phn if phn != " " else "<space>"
                for phn in phn_list
                if phn not in abnormal_phns
            ]

    # --- 3. Phoneme Duration Extraction --- #
    if "duration" in main_data.keys():
        # duration must arrive as a bracketed string such as "[0.1, 0.2]"
        assert isinstance(
            main_data["duration"], str
        ), f"The 'duration' data should be given as a string, but got {main_data['duration']}"
        # for the text data in the format of a list
        if main_data["duration"].startswith("[") and main_data["duration"].endswith(
            "]"
        ):
            main_data["duration"] = main_data["duration"][1:-1]
            # split the text into individual tokens by a comma followed a blank
            main_data["duration"] = main_data["duration"].split(", ")
            # remove the single quote marks surrounding each token if needed
            main_data["duration"] = [
                (
                    float(duration[1:-1])
                    if duration.startswith("'") and duration.endswith("'")
                    else float(duration)
                )
                for duration in main_data["duration"]
            ]
        else:
            raise RuntimeError(
                "The 'duration' string should be surrounded by a pair of square brackets!"
            )

    # --- 4. Silence Trimming at the two ends --- #
    # trim the silence at two ends of the waveforms if the phoneme sequence starts or ends with spaces
    if ("text" in main_data.keys() and isinstance(main_data["text"], List)) and (
        main_data["text"][0] == "<space>" or main_data["text"][-1] == "<space>"
    ):
        # trim both feat and text
        if "feat" in main_data.keys():
            assert "duration" in main_data.keys(), (
                "If you want to trim the silence at two ends of speech, "
                "please give 'duration' in 'main_data' of the item 'dataset_conf' under 'data_cfg'."
            )
            front_trim_len, tail_trim_len, total_duration = (
                0,
                0,
                sum(main_data["duration"]),
            )
            try:
                # sum up all the silence tokens at the beginning
                while main_data["text"][0] == "<space>":
                    front_trim_len += main_data["duration"][0]
                    main_data["text"], main_data["duration"] = (
                        main_data["text"][1:],
                        main_data["duration"][1:],
                    )
                # sum up all the silence tokens at the end
                while main_data["text"][-1] == "<space>":
                    tail_trim_len += main_data["duration"][-1]
                    main_data["text"], main_data["duration"] = (
                        main_data["text"][:-1],
                        main_data["duration"][:-1],
                    )
            # IndexError means the text is full of '<space>'
            # return None to remove this utterance from the current batch
            except IndexError:
                return None

            # normalize the trimming lengths by the total duration length
            front_trim_len, tail_trim_len = (
                front_trim_len / total_duration,
                tail_trim_len / total_duration,
            )
            # trim the extra silence in feat (waveforms or acoustic features)
            feat_start, feat_end = int(
                front_trim_len * len(main_data["feat"])
            ), int(tail_trim_len * len(main_data["feat"]))
            main_data["feat"] = main_data["feat"][feat_start:]
            if feat_end > 0:
                main_data["feat"] = main_data["feat"][:-feat_end]

            # also trim the two ends of pitch values if extracted
            if "pitch" in main_data.keys():
                pitch_start, pitch_end = int(
                    front_trim_len * len(main_data["pitch"])
                ), int(tail_trim_len * len(main_data["pitch"]))
                main_data["pitch"] = main_data["pitch"][pitch_start:]
                if pitch_end > 0:
                    main_data["pitch"] = main_data["pitch"][:-pitch_end]

        # only trim text if feat is not given
        else:
            try:
                # sum up all the <space> tokens at the beginning
                while main_data["text"][0] == "<space>":
                    main_data["text"] = main_data["text"][1:]
                    if "duration" in main_data.keys():
                        main_data["duration"] = main_data["duration"][1:]
                # sum up all the <space> tokens at the end
                while main_data["text"][-1] == "<space>":
                    main_data["text"] = main_data["text"][:-1]
                    if "duration" in main_data.keys():
                        main_data["duration"] = main_data["duration"][:-1]
            # IndexError means the text is full of '<space>'
            # return None to remove this utterance from the current batch
            except IndexError:
                return None

    # --- 5. Randomly Masking the text data by unknown tokens (After silence trimming for data safety) --- #
    if self.unk_mask_prob > 0:
        # fix: typo 'main_date' -> 'main_data' in the error message
        assert "text" in main_data.keys() and isinstance(
            main_data["text"], List
        ), "If you want to activate unk_mask_prob, text must be given in the 'main_data' tag as a token sequence."

        # Get the start and end indices of words based on the positions of space tokens
        space_indices = [
            i for i, token in enumerate(main_data["text"]) if token == "<space>"
        ]
        word_start_indices, word_end_indices = [0] + [
            s_i + 1 for s_i in space_indices
        ], space_indices + [len(main_data["text"])]

        # Determine which words to mask
        word_mask_flags = (
            np.random.rand(len(word_start_indices)) < self.unk_mask_prob
        )

        _tmp_text, _tmp_duration = [], []
        for i in range(len(word_mask_flags)):
            # If the word should be masked, add an '<unk>' token
            if word_mask_flags[i]:
                _tmp_text.append("<unk>")
                if "duration" in main_data.keys():
                    # the masked word keeps the summed duration of its tokens
                    _sum_duration = sum(
                        main_data["duration"][
                            word_start_indices[i] : word_end_indices[i]
                        ]
                    )
                    _tmp_duration.append(round(_sum_duration, 2))

            # If the word shouldn't be masked, add the original tokens of the word
            else:
                _tmp_text += main_data["text"][
                    word_start_indices[i] : word_end_indices[i]
                ]
                if "duration" in main_data.keys():
                    _tmp_duration += main_data["duration"][
                        word_start_indices[i] : word_end_indices[i]
                    ]

            # Add space tokens and their durations between words, except for the last word
            if i != len(word_mask_flags) - 1:
                _tmp_text.append(main_data["text"][word_end_indices[i]])
                if "duration" in main_data.keys():
                    _tmp_duration.append(main_data["duration"][word_end_indices[i]])

        # Update main_data with the new text and duration information
        main_data["text"] = _tmp_text
        if "duration" in main_data.keys():
            main_data["duration"] = _tmp_duration

    # --- 6. Speaker ID Extraction --- #
    if "spk_ids" in main_data.keys():
        # the speaker ID here is just a raw string
        assert isinstance(
            main_data["spk_ids"], str
        ), f"The 'spk_ids' data should be given as a string, but got {main_data['spk_ids']}"

    # --- 7. Speaker Embedding Feature --- #
    if "spk_feat" in main_data.keys():
        # read the selected data speech feature as a tensor by its path
        main_data["spk_feat"] = read_data_by_path(
            main_data["spk_feat"], return_tensor=True
        )

    return main_data