
feat_norm

Author: Heli Qi · Affiliation: NAIST · Date: 2022.09

FeatureNormalization

Bases: Module

The feature normalization frontend that transforms every feature dimension into a distribution with zero mean and unit variance.

As in SpeechBrain, we provide four kinds of feature normalization with different granularities:

  1. utterance-level normalization: the mean and std are calculated on each individual utterance.
  2. batch-level normalization: the mean and std are calculated on all the utterances in a training batch.
  3. group-level normalization: the mean and std are calculated on all the utterances in a group. A group indicates where an utterance comes from, so it can be any kind of data domain, such as different speakers, genders, or the source and target domains in a Domain Adaptation scenario.
  4. global-level normalization: the mean and std are calculated on all the utterances in the training set.

We approximate the group-level and global-level mean & std by taking their moving average during training. Unlike SpeechBrain, we initialize all the mean & std variables lazily in the forward() function. Another difference is that our moving average is updated once per batch, as BatchNorm does.

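Concretely, since the update weight at the n-th observed batch is 1/n, the per-batch moving average amounts to a cumulative average of the batch means observed so far. Below is a minimal sketch of this recursion in plain PyTorch (an illustration, not the module's internal code):

```python
import torch

# Cumulative-average recursion: with weight 1/n at the n-th observed batch,
# the running buffer equals the mean of all batch means seen so far.
running_mean, batch_count = None, 0
for batch_mean in [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), torch.tensor([5.0, 6.0])]:
    batch_count += 1
    w = 1.0 / batch_count
    running_mean = batch_mean if running_mean is None else w * batch_mean + (1 - w) * running_mean

print(running_mean)  # tensor([3., 4.]) == the mean of the three batch means
```
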
In DDP mode, the mean & std are synchronized across all the processes before being used to normalize the input utterances. The synchronization method differs between the following scenarios:

  1. group-level normalization where the input utterances may have different group ids (group_ids = torch.Tensor, e.g. different utterances in a single batch may belong to different speakers). In this scenario, the per-utterance mean & std vectors and the group ids are gathered across all the processes. The gathered vectors are then grouped by group id, and the mean & std of each group are calculated.

  2. global-level normalization, or group-level normalization where all the input utterances share the same group id (group_ids = str or int, e.g. all the utterances in the batch come from either the source domain or the target domain). In this scenario, only the per-process sums of the mean & std vectors are gathered instead of every individual vector, which reduces the communication volume across the processes. The true mean & std vectors are then recovered using the batch size of each process.

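A hypothetical usage sketch is shown below. The constructor call is an assumption: how the module is actually built depends on the speechain Module base class, which is expected to wire the module_init() keyword arguments through and to set input_size and distributed.

```python
import torch

# Hypothetical usage sketch; the constructor below is an assumption, not the documented API.
feat_norm = FeatureNormalization(input_size=80, norm_type="group")  # assumed constructor

feat = torch.randn(4, 120, 80)               # (batch, length, channel)
feat_len = torch.tensor([120, 100, 90, 80])  # number of valid frames per utterance
group_ids = torch.tensor([0, 0, 1, 1])       # e.g. speaker ids

feat_norm.train()                            # group buffers are updated during training
norm_feat, feat_len = feat_norm(feat, feat_len, group_ids=group_ids, epoch=1)

feat_norm.eval()                             # unseen groups fall back to the averaged stats
norm_feat, feat_len = feat_norm(feat, feat_len, group_ids=group_ids)
```
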
Source code in speechain/module/norm/feat_norm.py
class FeatureNormalization(Module):
    """The feature normalization frontend that makes every feature dimension the
    distribution with 0 mean and 1 variance.

    As SpeechBrain, we also provide four kinds of feature normalization with different granularities.
        1. utterance-level normalization: the mean and std are calculated on each individual utterance.
        2. batch-level normalization: the mean and std are calculated on all the utterances in a training batch.
        3. group-level normalization: the mean and std are calculated on all the utterances in a group.
            The group here means where the utterance comes from, so it can be any kinds of data domains
            such as different speakers, genders, source and target domains in Domain Adaptation scenario, and so on...
        4. global-level normalization: the mean and std are calculated on all the utterances in the training set.

    We approximate group-level and global-level mean & std by taking their moving average during training.
    Different from SpeechBrain, we initialize all the mean & std variables lazily in the forward() function.
    Another difference is that our moving average is calculated by each batch as BatchNorm does.

    In the DDP mode, the mean & std will be synchronized across all the processes before being used to normalize the
    input utterances. The synchronization method is different in different scenarios.

    1. group-level normalization where each input utterance has different group id (group_ids = torch.Tensor,
        e.g. different utterances in a single batch may belong to different speakers).
        In this scenario, the mean & std vectors of each utterance and the group ids will be gathered across all the
        processes. Then, the mean & std vectors will be picked up depending on the group id and the mean & std of the
        specific group will be calculated.

    2. global-level normalization or group-level normalization where all the input utterances have the same group id
        (group_ids = str or int, e.g. all the utterances in the batch come from either the source domain or the target
        domain). In this scenario, the summation of mean & std vectors will be gathered instead of all of them to reduce
        the data communication volume across all the processes. The real mean & std vectors will be recovered by the
        batch size of each process.
    """

    def module_init(
        self,
        norm_type: str = "global",
        mean_norm: bool = True,
        std_norm: bool = True,
        clamp: float = 1e-10,
        max_epoch_num: int = 4,
    ):
        """

        Args:
            norm_type: str
                The type of feature normalization.
                The type must be one of 'utterance', 'batch', 'group', and 'global'
            mean_norm: bool
                Controls whether the feature vectors will be normalized by their means
            std_norm: bool
                Controls whether the feature vectors will be normalized by their standard variance
            clamp: float
                Clamping threshold for the standard variance before division.
            max_epoch_num: int
                The maximum number of epochs used to calculate the moving average.
                Usually, the value of this argument is lower than a half of the number of warmup epochs.

        """
        self.norm_type = norm_type
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.clamp = clamp
        self.max_epoch_num = max_epoch_num

        if self.input_size is not None:
            self.output_size = self.input_size

    def forward(
        self,
        feat: torch.Tensor,
        feat_len: torch.Tensor,
        group_ids: torch.Tensor or str or int = None,
        epoch: int = None,
    ):
        """

        Args:
            feat: (batch, length, channel) or (batch, length)
                The normalization will be done on the channel dimension.
                If the feat is in the shape of (batch, length), it will be extended to (batch, length, 1)
            feat_len: (batch)
            group_ids: (batch)
            epoch:

        Returns:

        """
        if self.norm_type == "group":
            assert group_ids is not None, (
                "You are using group-level feature normalization, but group_ids is not given. "
                "Please check 'data_cfg' in your configuration."
            )
        # para preparation
        batch_size, squeeze_flag = feat.size(0), False
        if len(feat.shape) == 2:
            feat, squeeze_flag = feat.unsqueeze(-1), True
        elif len(feat.shape) != 3:
            raise RuntimeError(
                f"{self.__class__.__name__} only accepts the input vectors in the shape of "
                f"(batch, length, channel) or (batch, length), but got shape={feat.shape}!"
            )

        # --- Mean and Standard Variance Initialization --- #
        # calculate the mean values of all channels of all the input utterances
        curr_means = (
            None
            if not self.mean_norm
            else torch.stack(
                [feat[i][: feat_len[i]].mean(dim=0) for i in range(batch_size)]
            )
        )

        # calculate the std values of all channels of all the input utterances
        curr_stds = (
            None
            if not self.std_norm
            else torch.clamp(
                input=torch.stack(
                    [feat[i][: feat_len[i]].std(dim=0) for i in range(batch_size)]
                ),
                min=self.clamp,
            )
        )

        # --- Perform Normalization based on Different branches --- #
        # utterance-level normalization or group-level normalization without group_ids
        if self.norm_type == "utterance":
            feat = feat - curr_means.unsqueeze(1) if curr_means is not None else feat
            feat = feat / curr_stds.unsqueeze(1) if curr_stds is not None else feat

        # global-level & batch-level & group-level normalization (with group_ids)
        else:
            # only gather the batch sizes from other processes in the DDP model of training
            all_batch_size = None
            if self.training:
                all_batch_size = (
                    self.gather_scalars(batch_size, feat.device)
                    if self.distributed
                    else batch_size
                )

            # group-level normalization with tensor group_ids (input utterances belong to different groups)
            if self.norm_type == "group" and isinstance(group_ids, torch.Tensor):
                # only update the mean and std of the specific group during training
                if self.training:
                    # DDP mode
                    if self.distributed:
                        # gather all the group ids from other processes
                        all_group_ids = self.gather_vectors(group_ids, all_batch_size)
                        # gather all the mean vectors from other processes
                        all_curr_means = (
                            None
                            if curr_means is None
                            else self.gather_matrices(curr_means, all_batch_size)
                        )
                        # gather all the std vectors from other processes
                        all_curr_stds = (
                            None
                            if curr_stds is None
                            else self.gather_matrices(curr_stds, all_batch_size)
                        )
                    # single-GPU mode
                    else:
                        # not perform gathering
                        all_group_ids = group_ids
                        all_curr_means = curr_means
                        all_curr_stds = curr_stds

                    # record the mean of all groups in the current batch
                    group_mean_dict = self.sort_data_by_group(
                        raw_data=all_curr_means, group_ids=all_group_ids
                    )

                    # record the std of all groups in the current batch
                    group_std_dict = self.sort_data_by_group(
                        raw_data=all_curr_stds, group_ids=all_group_ids
                    )

                    # register the mean, std, and batch numbers into the buffer
                    group_keys = (
                        list(group_mean_dict.keys())
                        if group_mean_dict is not None
                        else list(group_std_dict.keys())
                    )
                    for group_id in group_keys:
                        self.register_mean_std_batch(
                            curr_aver_mean=(
                                group_mean_dict[group_id].mean(dim=0)
                                if group_mean_dict is not None
                                else None
                            ),
                            curr_aver_std=(
                                group_std_dict[group_id].mean(dim=0)
                                if group_std_dict is not None
                                else None
                            ),
                            prefix=group_id,
                            epoch=epoch,
                        )
                    # update the average mean & std of all the groups
                    # (i.e. the average distribution for unknown samples during inference)
                    self.update_aver_mean_std(epoch)

                # During training, normalize the known features by the group mean & std
                # During inference, normalize the unknown features by the average mean & std of all groups
                for i in range(batch_size):
                    group_id = group_ids[i].item() if group_ids is not None else None

                    if self.mean_norm:
                        feat[i] -= (
                            self.get_buffer("aver_mean")
                            if not hasattr(self, f"{group_id}_mean")
                            else self.get_buffer(f"{group_id}_mean")
                        )
                    if self.std_norm:
                        feat[i] /= (
                            self.get_buffer("aver_std")
                            if not hasattr(self, f"{group_id}_std")
                            else self.get_buffer(f"{group_id}_std")
                        )

            # batch-level & global-level normalization (these two scenarios share the batch-level mean & std)
            else:
                # only calculate the batch-level mean and std during training
                if self.training:
                    # gather the mean and std from the other processes in the DDP mode
                    if self.distributed:
                        # gather the sums of batch means from all the processes
                        batch_mean_sum = (
                            curr_means.sum(dim=0) if curr_means is not None else None
                        )
                        all_batch_mean_sums = (
                            self.gather_vectors(batch_mean_sum)
                            if batch_mean_sum is not None
                            else None
                        )
                        batch_mean = (
                            None
                            if all_batch_mean_sums is None
                            else all_batch_mean_sums.sum(dim=0) / all_batch_size.sum()
                        )

                        # gather the sums of batch stds from all the processes
                        batch_std_sum = (
                            curr_stds.sum(dim=0) if curr_stds is not None else None
                        )
                        all_batch_std_sums = (
                            self.gather_vectors(batch_std_sum)
                            if batch_std_sum is not None
                            else None
                        )
                        batch_std = (
                            None
                            if all_batch_std_sums is None
                            else all_batch_std_sums.sum(dim=0) / all_batch_size.sum()
                        )

                    # single-GPU mode
                    else:
                        batch_mean = (
                            curr_means.mean(dim=0) if curr_means is not None else None
                        )
                        batch_std = (
                            curr_stds.mean(dim=0) if curr_stds is not None else None
                        )

                # do nothing for batch-level mean and std during evaluation
                else:
                    batch_mean = None
                    batch_std = None

                # batch-level normalization
                if self.norm_type == "batch":
                    # normalize the input utterances by the batch mean and std during training
                    if self.training:
                        feat = feat - batch_mean if batch_mean is not None else feat
                        feat = feat / batch_std if batch_std is not None else feat
                    # normalize the input utterances by the utterance-specific mean and std during evaluation
                    else:
                        feat = (
                            feat - curr_means.unsqueeze(1)
                            if curr_means is not None
                            else feat
                        )
                        feat = (
                            feat / curr_stds.unsqueeze(1)
                            if curr_stds is not None
                            else feat
                        )

                # global-level normalization or
                # group-level normalization with str or int group_ids (input utterances belong to the same group)
                else:
                    assert self.norm_type in ["global", "group"], (
                        f"norm_type can only be one of 'utterance', 'batch', 'group', 'global', "
                        f"but got norm_type={self.norm_type}!"
                    )
                    if self.norm_type == "group":
                        assert isinstance(group_ids, (str, int)), (
                            f"If all the utterances in a single batch belong to the same group, "
                            f"you should give group_ids as a string or integer. "
                            f"But got type(group_ids)={type(group_ids)}."
                        )

                    # only update the mean and std during training
                    prefix = "global" if self.norm_type == "global" else group_ids
                    if self.training:
                        self.register_mean_std_batch(
                            curr_aver_mean=batch_mean,
                            curr_aver_std=batch_std,
                            prefix=prefix,
                            epoch=epoch,
                        )

                    # if the group_ids is given as a string or int,
                    # we assume that there are no unknown testing samples during inference
                    feat = (
                        feat - self.get_buffer(f"{prefix}_mean")
                        if curr_means is not None
                        else feat
                    )
                    feat = (
                        feat / self.get_buffer(f"{prefix}_std")
                        if curr_stds is not None
                        else feat
                    )

        return feat.squeeze(-1) if squeeze_flag else feat, feat_len

    @staticmethod
    def gather_scalars(scalar: int, device: torch.device) -> torch.LongTensor:
        # gather the input scalars
        all_scalars = [
            torch.LongTensor([0]).cuda(device)
            for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(
            all_scalars, torch.LongTensor([scalar]).cuda(device)
        )
        return torch.LongTensor(all_scalars)

    @staticmethod
    def gather_vectors(
        vector: torch.Tensor, all_batch_size: torch.Tensor = None
    ) -> torch.Tensor:
        # vectors of all the processes may have different length
        if all_batch_size is not None:
            curr_batch_size = all_batch_size[torch.distributed.get_rank()].item()
            max_batch_size = all_batch_size.max().item()
            if curr_batch_size < max_batch_size:
                vector = torch.cat(
                    (
                        vector,
                        torch.zeros(
                            max_batch_size - curr_batch_size,
                            dtype=vector.dtype,
                            device=vector.device,
                        ),
                    )
                )
            all_vectors = [
                torch.Tensor([0 for _ in range(max_batch_size)])
                .type_as(vector)
                .cuda(vector.device)
                for _ in range(torch.distributed.get_world_size())
            ]
        # all the vectors are equal in length
        else:
            all_vectors = [
                torch.zeros_like(vector, device=vector.device)
                for _ in range(torch.distributed.get_world_size())
            ]

        # gather the vectors from other processes to all_vectors
        torch.distributed.all_gather(all_vectors, vector)

        # remove the padding
        return (
            torch.stack(all_vectors)
            if all_batch_size is None
            else torch.cat(
                [all_vectors[i][: all_batch_size[i]] for i in range(len(all_vectors))]
            )
        )

    @staticmethod
    def gather_matrices(
        matrix: torch.Tensor, all_batch_size: torch.Tensor
    ) -> torch.Tensor:
        curr_batch_size = all_batch_size[torch.distributed.get_rank()].item()
        max_batch_size = all_batch_size.max().item()
        # padding the matrix if necessary
        if curr_batch_size < max_batch_size:
            matrix = torch.cat(
                (
                    matrix,
                    torch.zeros(
                        max_batch_size - curr_batch_size,
                        matrix.size(-1),
                        device=matrix.device,
                    ),
                )
            )

        # gather the matrices from other processes to all_matrices
        all_matrices = [
            torch.zeros_like(matrix, device=matrix.device)
            for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(all_matrices, matrix)

        # remove the padding
        return torch.cat(
            [all_matrices[i][: all_batch_size[i]] for i in range(len(all_matrices))]
        )

    @staticmethod
    def sort_data_by_group(raw_data: torch.Tensor, group_ids: torch.Tensor):
        """

        Args:
            raw_data:
            group_ids:

        Returns:

        """
        if raw_data is None:
            return None
        else:
            group_dict = dict()
            # loop each group id
            for i in range(group_ids.size(0)):
                curr_group = group_ids[i].item()
                # initialize the group list if not existed
                if curr_group not in group_dict.keys():
                    group_dict[curr_group] = []
                group_dict[curr_group].append(raw_data[i])
            # turn each group list into a 2d tensor
            return {
                group_id: torch.stack(group_list)
                for group_id, group_list in group_dict.items()
            }

    def register_mean_std_batch(
        self,
        curr_aver_mean: torch.Tensor,
        curr_aver_std: torch.Tensor,
        prefix: str,
        epoch: int,
    ):
        """

        Args:
            curr_aver_mean:
            curr_aver_std:
            prefix:
            epoch:

        """
        # update the observed global batch number
        if epoch is None or not hasattr(self, f"{prefix}_batch"):
            self.register_buffer(
                f"{prefix}_batch",
                torch.LongTensor([1]).cuda(device=curr_aver_mean.device),
            )
        elif epoch <= self.max_epoch_num:
            self.register_buffer(
                f"{prefix}_batch", self.get_buffer(f"{prefix}_batch") + 1
            )

        # update the observed global mean & std only in the predefined batch number
        if epoch is None or epoch <= self.max_epoch_num:
            # get the weight of the global average values
            curr_weight = 1 / self.get_buffer(f"{prefix}_batch")

            # update the observed global mean
            if self.mean_norm:
                if not hasattr(self, f"{prefix}_mean"):
                    self.register_buffer(f"{prefix}_mean", curr_aver_mean)
                else:
                    prev_aver_mean = self.get_buffer(f"{prefix}_mean")
                    self.register_buffer(
                        f"{prefix}_mean",
                        curr_weight * curr_aver_mean
                        + (1 - curr_weight) * prev_aver_mean,
                    )

            # update the observed global std
            if self.std_norm:
                if not hasattr(self, f"{prefix}_std"):
                    self.register_buffer(f"{prefix}_std", curr_aver_std)
                else:
                    prev_aver_std = self.get_buffer(f"{prefix}_std")
                    self.register_buffer(
                        f"{prefix}_std",
                        curr_weight * curr_aver_std + (1 - curr_weight) * prev_aver_std,
                    )

    def update_aver_mean_std(self, epoch: int):
        """

        Args:
            epoch:

        """
        if epoch is None or epoch <= self.max_epoch_num:
            _group_mean_num, _group_std_num = 0, 0
            _aver_mean, _aver_std = None, None
            for name, buff in self.named_buffers():
                if name.endswith("_mean"):
                    _group_mean_num += 1
                    _aver_mean = (
                        buff.clone() if _aver_mean is None else _aver_mean + buff
                    )
                elif name.endswith("_std"):
                    _group_std_num += 1
                    _aver_std = buff.clone() if _aver_std is None else _aver_std + buff

            self.register_buffer("aver_mean", _aver_mean / _group_mean_num)
            self.register_buffer("aver_std", _aver_std / _group_std_num)

    def recover(self, feat: torch.Tensor, group_ids: torch.Tensor or str or int = None):
        """

        Args:
            feat:
            group_ids:

        Returns:

        """
        assert self.norm_type not in [
            "utterance",
            "batch",
        ], "If norm_type is either 'utterance' or 'batch', the normalized features cannot be recovered."

        # global normalization or
        # group-level normalization with str or int group_ids (input utterances belong to the same group)
        if self.norm_type == "global" or (
            self.norm_type == "group" and isinstance(group_ids, (str, int))
        ):
            prefix = "global" if self.norm_type == "global" else str(group_ids)
            feat = feat * self.get_buffer(f"{prefix}_std") if self.std_norm else feat
            feat = feat + self.get_buffer(f"{prefix}_mean") if self.mean_norm else feat
        # group-level normalization with tensor group_ids (input utterances belong to different groups)
        # recover by the average mean & std when meeting an unknown group during inference
        elif self.norm_type == "group" and isinstance(group_ids, torch.Tensor):
            feat = (
                feat
                * torch.stack(
                    [
                        (
                            self.get_buffer(f"{g_id.item():d}_std")
                            if hasattr(self, f"{g_id.item():d}_std")
                            else self.get_buffer("aver_std")
                        )
                        for g_id in group_ids
                    ],
                    dim=0,
                ).unsqueeze(1)
                if self.std_norm
                else feat
            )

            feat = (
                feat
                + torch.stack(
                    [
                        (
                            self.get_buffer(f"{g_id.item():d}_mean")
                            if hasattr(self, f"{g_id.item():d}_mean")
                            else self.get_buffer("aver_mean")
                        )
                        for g_id in group_ids
                    ],
                    dim=0,
                ).unsqueeze(1)
                if self.mean_norm
                else feat
            )
        # group-level normalization with None group_ids, recover by the average mean & std
        elif self.norm_type == "group" and group_ids is None:
            feat = (
                feat * self.get_buffer("aver_std").expand(1, 1, -1)
                if self.std_norm
                else feat
            )
            feat = (
                feat + self.get_buffer("aver_mean").expand(1, 1, -1)
                if self.mean_norm
                else feat
            )
        else:
            raise RuntimeError

        return feat

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Lazily register all the buffer variables ending with '_batch', '_std', or
        '_mean' from state_dict to self."""
        for key in state_dict.keys():
            if key.startswith(prefix):
                input_name = key[len(prefix) :].split(".", 1)[0]

                if "_" in input_name and input_name.split("_")[-1] in [
                    "batch",
                    "std",
                    "mean",
                ]:
                    self.register_buffer(input_name, state_dict[key])
                else:
                    unexpected_keys.append(key)

    def extra_repr(self) -> str:
        return f"norm_type={self.norm_type}, mean_norm={self.mean_norm}, std_norm={self.std_norm}"

forward(feat, feat_len, group_ids=None, epoch=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| feat | Tensor | (batch, length, channel) or (batch, length). The normalization will be done on the channel dimension. If feat is in the shape of (batch, length), it will be extended to (batch, length, 1). | required |
| feat_len | Tensor | (batch) | required |
| group_ids | Tensor or str or int | (batch) | None |
| epoch | int |  | None |

Returns:

Source code in speechain/module/norm/feat_norm.py
def forward(
    self,
    feat: torch.Tensor,
    feat_len: torch.Tensor,
    group_ids: torch.Tensor or str or int = None,
    epoch: int = None,
):
    """

    Args:
        feat: (batch, length, channel) or (batch, length)
            The normalization will be done on the channel dimension.
            If the feat is in the shape of (batch, length), it will be extended to (batch, length, 1)
        feat_len: (batch)
        group_ids: (batch)
        epoch:

    Returns:

    """
    if self.norm_type == "group":
        assert group_ids is not None, (
            "You are using group-level feature normalization, but group_ids is not given. "
            "Please check 'data_cfg' in your configuration."
        )
    # para preparation
    batch_size, squeeze_flag = feat.size(0), False
    if len(feat.shape) == 2:
        feat, squeeze_flag = feat.unsqueeze(-1), True
    elif len(feat.shape) != 3:
        raise RuntimeError(
            f"{self.__class__.__name__} only accepts the input vectors in the shape of "
            f"(batch, length, channel) or (batch, length), but got shape={feat.shape}!"
        )

    # --- Mean and Standard Variance Initialization --- #
    # calculate the mean values of all channels of all the input utterances
    curr_means = (
        None
        if not self.mean_norm
        else torch.stack(
            [feat[i][: feat_len[i]].mean(dim=0) for i in range(batch_size)]
        )
    )

    # calculate the std values of all channels of all the input utterances
    curr_stds = (
        None
        if not self.std_norm
        else torch.clamp(
            input=torch.stack(
                [feat[i][: feat_len[i]].std(dim=0) for i in range(batch_size)]
            ),
            min=self.clamp,
        )
    )

    # --- Perform Normalization based on Different branches --- #
    # utterance-level normalization or group-level normalization without group_ids
    if self.norm_type == "utterance":
        feat = feat - curr_means.unsqueeze(1) if curr_means is not None else feat
        feat = feat / curr_stds.unsqueeze(1) if curr_stds is not None else feat

    # global-level & batch-level & group-level normalization (with group_ids)
    else:
        # only gather the batch sizes from other processes in the DDP model of training
        all_batch_size = None
        if self.training:
            all_batch_size = (
                self.gather_scalars(batch_size, feat.device)
                if self.distributed
                else batch_size
            )

        # group-level normalization with tensor group_ids (input utterances belong to different groups)
        if self.norm_type == "group" and isinstance(group_ids, torch.Tensor):
            # only update the mean and std of the specific group during training
            if self.training:
                # DDP mode
                if self.distributed:
                    # gather all the group ids from other processes
                    all_group_ids = self.gather_vectors(group_ids, all_batch_size)
                    # gather all the mean vectors from other processes
                    all_curr_means = (
                        None
                        if curr_means is None
                        else self.gather_matrices(curr_means, all_batch_size)
                    )
                    # gather all the std vectors from other processes
                    all_curr_stds = (
                        None
                        if curr_stds is None
                        else self.gather_matrices(curr_stds, all_batch_size)
                    )
                # single-GPU mode
                else:
                    # not perform gathering
                    all_group_ids = group_ids
                    all_curr_means = curr_means
                    all_curr_stds = curr_stds

                # record the mean of all groups in the current batch
                group_mean_dict = self.sort_data_by_group(
                    raw_data=all_curr_means, group_ids=all_group_ids
                )

                # record the std of all groups in the current batch
                group_std_dict = self.sort_data_by_group(
                    raw_data=all_curr_stds, group_ids=all_group_ids
                )

                # register the mean, std, and batch numbers into the buffer
                group_keys = (
                    list(group_mean_dict.keys())
                    if group_mean_dict is not None
                    else list(group_std_dict.keys())
                )
                for group_id in group_keys:
                    self.register_mean_std_batch(
                        curr_aver_mean=(
                            group_mean_dict[group_id].mean(dim=0)
                            if group_mean_dict is not None
                            else None
                        ),
                        curr_aver_std=(
                            group_std_dict[group_id].mean(dim=0)
                            if group_std_dict is not None
                            else None
                        ),
                        prefix=group_id,
                        epoch=epoch,
                    )
                # update the average mean & std of all the groups
                # (i.e. the average distribution for unknown samples during inference)
                self.update_aver_mean_std(epoch)

            # During training, normalize the known features by the group mean & std
            # During inference, normalize the unknown features by the average mean & std of all groups
            for i in range(batch_size):
                group_id = group_ids[i].item() if group_ids is not None else None

                if self.mean_norm:
                    feat[i] -= (
                        self.get_buffer("aver_mean")
                        if not hasattr(self, f"{group_id}_mean")
                        else self.get_buffer(f"{group_id}_mean")
                    )
                if self.std_norm:
                    feat[i] /= (
                        self.get_buffer("aver_std")
                        if not hasattr(self, f"{group_id}_std")
                        else self.get_buffer(f"{group_id}_std")
                    )

        # batch-level & global-level normalization (these two scenarios share the batch-level mean & std)
        else:
            # only calculate the batch-level mean and std during training
            if self.training:
                # gather the mean and std from the other processes in the DDP mode
                if self.distributed:
                    # gather the sums of batch means from all the processes
                    batch_mean_sum = (
                        curr_means.sum(dim=0) if curr_means is not None else None
                    )
                    all_batch_mean_sums = (
                        self.gather_vectors(batch_mean_sum)
                        if batch_mean_sum is not None
                        else None
                    )
                    batch_mean = (
                        None
                        if all_batch_mean_sums is None
                        else all_batch_mean_sums.sum(dim=0) / all_batch_size.sum()
                    )

                    # gather the sums of batch stds from all the processes
                    batch_std_sum = (
                        curr_stds.sum(dim=0) if curr_stds is not None else None
                    )
                    all_batch_std_sums = (
                        self.gather_vectors(batch_std_sum)
                        if batch_std_sum is not None
                        else None
                    )
                    batch_std = (
                        None
                        if all_batch_std_sums is None
                        else all_batch_std_sums.sum(dim=0) / all_batch_size.sum()
                    )

                # single-GPU mode
                else:
                    batch_mean = (
                        curr_means.mean(dim=0) if curr_means is not None else None
                    )
                    batch_std = (
                        curr_stds.mean(dim=0) if curr_stds is not None else None
                    )

            # do nothing for batch-level mean and std during evaluation
            else:
                batch_mean = None
                batch_std = None

            # batch-level normalization
            if self.norm_type == "batch":
                # normalize the input utterances by the batch mean and std during training
                if self.training:
                    feat = feat - batch_mean if batch_mean is not None else feat
                    feat = feat / batch_std if batch_std is not None else feat
                # normalize the input utterances by the utterance-specific mean and std during evaluation
                else:
                    feat = (
                        feat - curr_means.unsqueeze(1)
                        if curr_means is not None
                        else feat
                    )
                    feat = (
                        feat / curr_stds.unsqueeze(1)
                        if curr_stds is not None
                        else feat
                    )

            # global-level normalization or
            # group-level normalization with str or int group_ids (input utterances belong to the same group)
            else:
                assert self.norm_type in ["global", "group"], (
                    f"norm_type can only be one of 'utterance', 'batch', 'group', 'global', "
                    f"but got norm_type={self.norm_type}!"
                )
                if self.norm_type == "group":
                    assert isinstance(group_ids, (str, int)), (
                        f"If all the utterances in a single batch belong to the same group, "
                        f"you should give group_ids as a string or integer. "
                        f"But got type(group_ids)={type(group_ids)}."
                    )

                # only update the mean and std during training
                prefix = "global" if self.norm_type == "global" else group_ids
                if self.training:
                    self.register_mean_std_batch(
                        curr_aver_mean=batch_mean,
                        curr_aver_std=batch_std,
                        prefix=prefix,
                        epoch=epoch,
                    )

                # if the group_ids is given as a string or int,
                # we assume that there are no unknown testing samples during inference
                feat = (
                    feat - self.get_buffer(f"{prefix}_mean")
                    if curr_means is not None
                    else feat
                )
                feat = (
                    feat / self.get_buffer(f"{prefix}_std")
                    if curr_stds is not None
                    else feat
                )

    return feat.squeeze(-1) if squeeze_flag else feat, feat_len

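The following single-process sketch illustrates the second DDP synchronization scenario described in the class overview: exchanging only the per-process sums of the utterance means (simulated here by slicing one tensor into two hypothetical ranks) recovers exactly the batch-level mean that gathering every individual vector would give.

```python
import torch

# Single-process sketch of the "gather sums, not vectors" trick.
per_utt_means = torch.randn(6, 80)                   # 6 utterances, 80 channels
rank0, rank1 = per_utt_means[:4], per_utt_means[4:]  # uneven "per-process" batches

sum_of_sums = rank0.sum(dim=0) + rank1.sum(dim=0)    # what gathering the per-process sums yields
total_batch = rank0.size(0) + rank1.size(0)          # what gathering the batch sizes yields
batch_mean = sum_of_sums / total_batch

assert torch.allclose(batch_mean, per_utt_means.mean(dim=0), atol=1e-6)
```
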
module_init(norm_type='global', mean_norm=True, std_norm=True, clamp=1e-10, max_epoch_num=4)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| norm_type | str | The type of feature normalization. Must be one of 'utterance', 'batch', 'group', and 'global'. | 'global' |
| mean_norm | bool | Controls whether the feature vectors will be normalized by their means. | True |
| std_norm | bool | Controls whether the feature vectors will be normalized by their standard variance. | True |
| clamp | float | Clamping threshold for the standard variance before division. | 1e-10 |
| max_epoch_num | int | The maximum number of epochs used to calculate the moving average. Usually set lower than half of the number of warmup epochs. | 4 |
Source code in speechain/module/norm/feat_norm.py
def module_init(
    self,
    norm_type: str = "global",
    mean_norm: bool = True,
    std_norm: bool = True,
    clamp: float = 1e-10,
    max_epoch_num: int = 4,
):
    """

    Args:
        norm_type: str
            The type of feature normalization.
            The type must be one of 'utterance', 'batch', 'group', and 'global'
        mean_norm: bool
            Controls whether the feature vectors will be normalized by their means
        std_norm: bool
            Controls whether the feature vectors will be normalized by their standard variance
        clamp: float
            Clamping threshold for the standard variance before division.
        max_epoch_num: int
            The maximum number of epochs used to calculate the moving average.
            Usually, the value of this argument is lower than a half of the number of warmup epochs.

    """
    self.norm_type = norm_type
    self.mean_norm = mean_norm
    self.std_norm = std_norm
    self.clamp = clamp
    self.max_epoch_num = max_epoch_num

    if self.input_size is not None:
        self.output_size = self.input_size

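The snippet below is a small illustration of why the clamp argument exists: a channel that stays constant over an utterance has a (near-)zero standard deviation, and dividing by it would blow the normalized features up. Clamping the std to a small floor before the division keeps it finite.

```python
import torch

feat = torch.randn(1, 50, 3)
feat[..., 2] = 0.7                      # make one channel constant over time
std = feat[0].std(dim=0)                # per-channel std; the last entry is ~0
safe_std = torch.clamp(std, min=1e-10)  # the floor applied before normalization (clamp=1e-10)
print(std, safe_std)
```
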
recover(feat, group_ids=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| feat | Tensor |  | required |
| group_ids | Tensor or str or int |  | None |

Returns:

Source code in speechain/module/norm/feat_norm.py
def recover(self, feat: torch.Tensor, group_ids: torch.Tensor or str or int = None):
    """

    Args:
        feat:
        group_ids:

    Returns:

    """
    assert self.norm_type not in [
        "utterance",
        "batch",
    ], "If norm_type is either 'utterance' or 'batch', the normalized features cannot be recovered."

    # global normalization or
    # group-level normalization with str or int group_ids (input utterances belong to the same group)
    if self.norm_type == "global" or (
        self.norm_type == "group" and isinstance(group_ids, (str, int))
    ):
        prefix = "global" if self.norm_type == "global" else str(group_ids)
        feat = feat * self.get_buffer(f"{prefix}_std") if self.std_norm else feat
        feat = feat + self.get_buffer(f"{prefix}_mean") if self.mean_norm else feat
    # group-level normalization with tensor group_ids (input utterances belong to different groups)
    # recover by the average mean & std when meeting an unknown group during inference
    elif self.norm_type == "group" and isinstance(group_ids, torch.Tensor):
        feat = (
            feat
            * torch.stack(
                [
                    (
                        self.get_buffer(f"{g_id.item():d}_std")
                        if hasattr(self, f"{g_id.item():d}_std")
                        else self.get_buffer("aver_std")
                    )
                    for g_id in group_ids
                ],
                dim=0,
            ).unsqueeze(1)
            if self.std_norm
            else feat
        )

        feat = (
            feat
            + torch.stack(
                [
                    (
                        self.get_buffer(f"{g_id.item():d}_mean")
                        if hasattr(self, f"{g_id.item():d}_mean")
                        else self.get_buffer("aver_mean")
                    )
                    for g_id in group_ids
                ],
                dim=0,
            ).unsqueeze(1)
            if self.mean_norm
            else feat
        )
    # group-level normalization with None group_ids, recover by the average mean & std
    elif self.norm_type == "group" and group_ids is None:
        feat = (
            feat * self.get_buffer("aver_std").expand(1, 1, -1)
            if self.std_norm
            else feat
        )
        feat = (
            feat + self.get_buffer("aver_mean").expand(1, 1, -1)
            if self.mean_norm
            else feat
        )
    else:
        raise RuntimeError

    return feat

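Because the group- and global-level statistics are kept in buffers, recover() can invert the normalization exactly. The arithmetic it applies is sketched below on plain tensors, with mu and sigma standing in for the buffered {prefix}_mean and {prefix}_std of a known group (an illustration of the math, not a call into the module).

```python
import torch

x = torch.randn(2, 30, 80)
mu, sigma = x.mean(dim=(0, 1)), x.std(dim=(0, 1)).clamp(min=1e-10)

normalized = (x - mu) / sigma        # what the forward pass produces
recovered = normalized * sigma + mu  # what recover() applies: std back in, mean back on
assert torch.allclose(recovered, x, atol=1e-5)
```
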
register_mean_std_batch(curr_aver_mean, curr_aver_std, prefix, epoch)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| curr_aver_mean | Tensor |  | required |
| curr_aver_std | Tensor |  | required |
| prefix | str |  | required |
| epoch | int |  | required |
Source code in speechain/module/norm/feat_norm.py
def register_mean_std_batch(
    self,
    curr_aver_mean: torch.Tensor,
    curr_aver_std: torch.Tensor,
    prefix: str,
    epoch: int,
):
    """

    Args:
        curr_aver_mean:
        curr_aver_std:
        prefix:
        epoch:

    """
    # update the observed global batch number
    if epoch is None or not hasattr(self, f"{prefix}_batch"):
        self.register_buffer(
            f"{prefix}_batch",
            torch.LongTensor([1]).cuda(device=curr_aver_mean.device),
        )
    elif epoch <= self.max_epoch_num:
        self.register_buffer(
            f"{prefix}_batch", self.get_buffer(f"{prefix}_batch") + 1
        )

    # update the observed global mean & std only in the predefined batch number
    if epoch is None or epoch <= self.max_epoch_num:
        # get the weight of the global average values
        curr_weight = 1 / self.get_buffer(f"{prefix}_batch")

        # update the observed global mean
        if self.mean_norm:
            if not hasattr(self, f"{prefix}_mean"):
                self.register_buffer(f"{prefix}_mean", curr_aver_mean)
            else:
                prev_aver_mean = self.get_buffer(f"{prefix}_mean")
                self.register_buffer(
                    f"{prefix}_mean",
                    curr_weight * curr_aver_mean
                    + (1 - curr_weight) * prev_aver_mean,
                )

        # update the observed global std
        if self.std_norm:
            if not hasattr(self, f"{prefix}_std"):
                self.register_buffer(f"{prefix}_std", curr_aver_std)
            else:
                prev_aver_std = self.get_buffer(f"{prefix}_std")
                self.register_buffer(
                    f"{prefix}_std",
                    curr_weight * curr_aver_std + (1 - curr_weight) * prev_aver_std,
                )

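A hypothetical mock of the buffer-naming scheme used here: one "{prefix}_batch", "{prefix}_mean", and "{prefix}_std" buffer per group prefix (including the "global" prefix), all registered lazily the first time that prefix is seen. The plain nn.Module below is only a stand-in to show the resulting buffer names.

```python
import torch
from torch import nn

m = nn.Module()
for prefix, mean in [("global", torch.zeros(80)), ("12", torch.ones(80))]:
    m.register_buffer(f"{prefix}_batch", torch.LongTensor([1]))  # observed batch count
    m.register_buffer(f"{prefix}_mean", mean)                    # running mean
    m.register_buffer(f"{prefix}_std", torch.ones(80))           # running std

print([name for name, _ in m.named_buffers()])
# ['global_batch', 'global_mean', 'global_std', '12_batch', '12_mean', '12_std']
```
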
sort_data_by_group(raw_data, group_ids) staticmethod

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| raw_data | Tensor |  | required |
| group_ids | Tensor |  | required |

Returns:

Source code in speechain/module/norm/feat_norm.py
@staticmethod
def sort_data_by_group(raw_data: torch.Tensor, group_ids: torch.Tensor):
    """

    Args:
        raw_data:
        group_ids:

    Returns:

    """
    if raw_data is None:
        return None
    else:
        group_dict = dict()
        # loop each group id
        for i in range(group_ids.size(0)):
            curr_group = group_ids[i].item()
            # initialize the group list if not existed
            if curr_group not in group_dict.keys():
                group_dict[curr_group] = []
            group_dict[curr_group].append(raw_data[i])
        # turn each group list into a 2d tensor
        return {
            group_id: torch.stack(group_list)
            for group_id, group_list in group_dict.items()
        }

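A small, self-contained example of what this static helper returns: rows of raw_data that share a group id are stacked into one 2-D tensor per id.

```python
import torch

means = torch.arange(8.0).view(4, 2)     # 4 per-utterance mean vectors, 2 channels
group_ids = torch.tensor([7, 3, 7, 3])   # e.g. speaker ids

grouped = FeatureNormalization.sort_data_by_group(raw_data=means, group_ids=group_ids)
# {7: tensor([[0., 1.], [4., 5.]]), 3: tensor([[2., 3.], [6., 7.]])}
```
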
update_aver_mean_std(epoch)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epoch | int |  | required |
Source code in speechain/module/norm/feat_norm.py
def update_aver_mean_std(self, epoch: int):
    """

    Args:
        epoch:

    """
    if epoch is None or epoch <= self.max_epoch_num:
        _group_mean_num, _group_std_num = 0, 0
        _aver_mean, _aver_std = None, None
        for name, buff in self.named_buffers():
            if name.endswith("_mean"):
                _group_mean_num += 1
                _aver_mean = (
                    buff.clone() if _aver_mean is None else _aver_mean + buff
                )
            elif name.endswith("_std"):
                _group_std_num += 1
                _aver_std = buff.clone() if _aver_std is None else _aver_std + buff

        self.register_buffer("aver_mean", _aver_mean / _group_mean_num)
        self.register_buffer("aver_std", _aver_std / _group_std_num)