
abs

Abstract base class for all models.

Author: Heli Qi | Affiliation: NAIST | Date: 2022.07

Model

Bases: Module, ABC

Model is the base class for all models in this toolkit. The main job of a model includes:

1. (optional) preprocess the input batch data to the trainable format
2. calculate the model prediction results by the Module members
3. evaluate the prediction results by the Criterion members

Each model has several built-in Module members that make up the neural network structure of the model. These Module members will be initialized by the module_conf given in your configuration.

There are a built-in dictionary named `init_class_dict` and a built-in list named `default_init_modules` in the base class. `init_class_dict` contains all the available initialization functions of the model parameters, while `default_init_modules` includes the network layers that have their own initialization functions.

Attributes:

- init_class_dict (Dict): Available parameter initialization functions
- default_init_modules (List): Network layers with their own initialization functions
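
To make the initialization machinery concrete, here is a minimal, self-contained sketch that mirrors the parameter-initialization logic of this base class (the toy network and the 'kaiming_uniform' choice are illustrative assumptions, not part of the toolkit):

```python
import torch

# subset of Model.init_class_dict, for brevity
init_class_dict = {
    "xavier": torch.nn.init.xavier_normal_,
    "kaiming_uniform": torch.nn.init.kaiming_uniform_,
}
# layers that fall back to their own reset_parameters()
default_init_modules = [torch.nn.Embedding, torch.nn.LayerNorm]

# a toy network standing in for the Module members of a real model
net = torch.nn.Sequential(
    torch.nn.Embedding(10, 8),
    torch.nn.Linear(8, 8),
    torch.nn.LayerNorm(8),
)

init = "kaiming_uniform"  # would come from model_conf['init']
for name, para in net.named_parameters():
    if ".bias" in name and para.dim() == 1:
        torch.nn.init.zeros_(para)   # bias vectors are zeroed
    elif para.dim() > 1:
        init_class_dict[init](para)  # weight matrices use the chosen function

# modules with their own schemes re-initialize themselves afterwards
for module in net.modules():
    if isinstance(module, tuple(default_init_modules)):
        module.reset_parameters()
```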

Source code in speechain/model/abs.py
class Model(torch.nn.Module, ABC):
    """
    Model is the base class for all models in this toolkit. The main job of a model includes:
        1. (optional) preprocess the input batch data to the trainable format
        2. calculate the model prediction results by the Module members
        3. evaluate the prediction results by the Criterion members

    Each model has several built-in Module members that make up the neural network structure of the model. These Module
    members will be initialized by the `module_conf` given in your configuration.

    There are a built-in dictionary named `init_class_dict` and a built-in list named `default_init_modules` in the
    base class. `init_class_dict` contains all the available initialization functions of the model parameters while
    `default_init_modules` includes the network layers that have their own initialization functions.

    Attributes:
        init_class_dict (Dict): Available parameter initialization functions
        default_init_modules (List): Network layers with own initialization functions

    """

    # available parameter initialization functions
    init_class_dict: Dict = {
        "xavier": torch.nn.init.xavier_normal_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "kaiming": torch.nn.init.kaiming_normal_,
        "kaiming_normal": torch.nn.init.kaiming_normal_,
        "kaiming_uniform": torch.nn.init.kaiming_uniform_,
        "uniform": torch.nn.init.uniform_,
        "normal": torch.nn.init.normal_,
        "zeros": torch.nn.init.zeros_,
    }

    # some modules have their own parameter initialization methods
    default_init_modules: List = [  # explicitly defined
        torch.nn.Embedding,
        torch.nn.LayerNorm,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        PositionalEncoding,
    ]

    def __init__(
        self,
        device: torch.device,
        module_conf: Dict,
        result_path: str,
        model_conf: Dict = None,
        criterion_conf: Dict = None,
        non_blocking: bool = False,
        distributed: bool = False,
    ):
        """In this initialization function, there are two parts of
        initialization: model-specific customized initialization and model-
        independent general initialization.

        Model-specific customized initialization is done by two interface functions: module_init() and criterion_init().
        module_init() initializes the neural network structure of the model while criterion_init() initializes the
        criteria used to optimize (loss functions) and evaluate (validation metrics) the model.

        After the customized initialization, there are 3 steps for general initialization:
            1. Pretrained parameters will be loaded into your model if the key `pretrained_model` is given. Multiple
            pretrained models can be specified and each of them can be loaded into different parts of your model. The
            mismatch between the names of pretrained parameters and the parameters of your model is handled by the key
            'mapping'. The value of the key `mapping` is a dictionary where each key-value item corresponds to a mapping
            of parameter names. The key is the parameter name in the pretrained parameters while the value is the
            parameter name of your model.

            2. If `pretrained_model` is not given, the parameters of your model will be initialized by the function that
            matches your input query 'init'. Please refer to the built-in dictionary `init_class_dict` for the available
            initialization functions. If `init` is not given, the default initialization function
            `torch.nn.init.xavier_normal_` will be used to initialize your model.

            3. Finally, the specified parts of your model will be frozen if 'frozen_modules' is given. If there is only
            one frozen module, you can directly give the string of its name to 'frozen_modules' like
            'frozen_modules: {module_name}'; if there are multiple modules you want to freeze, you can give their names
            in a list as
            ```
            frozen_modules:
              - {module_name1}
              - {module_name2}
              - ...
            ```
            Moreover, the frozen granularity depends on your input `frozen_modules`.
            For example,
                1. If you give 'frozen_modules: encoder_prenet', all parameters of the prenet of your encoder will be
                frozen
                2. If you give 'frozen_modules: encoder_prenet.conv', only the convolution layers of the prenet of your
                encoder will be frozen
                3. If you give 'frozen_modules: encoder_prenet.conv.0', only the first convolution layer of the prenet
                of your encoder will be frozen
                4. If you give 'frozen_modules: encoder_prenet.conv.0.bias', only the bias vector of the first
                convolution layer of the prenet of your encoder will be frozen

        Args:
            device (torch.device):
                The computational device used for model calculation in the current GPU process.
            model_conf (Dict):
                The model configuration used for general model initialization.
            module_conf (Dict):
                The module configuration used for network structure initialization.
            criterion_conf (Dict):
                The criterion configuration used for criterion (loss functions and evaluation metrics) initialization.
        """
        super(Model, self).__init__()

        # input argument checking
        assert module_conf is not None, "module_conf cannot be None!"
        # model_conf defaults to an empty dictionary
        model_conf = dict() if model_conf is None else model_conf
        # criterion_conf defaults to an empty dictionary
        criterion_conf = dict() if criterion_conf is None else criterion_conf
        # customize_conf defaults to an empty dictionary
        if "customize_conf" not in model_conf.keys():
            model_conf["customize_conf"] = dict()

        # general argument registration
        self.non_blocking = non_blocking
        self.distributed = distributed
        self.device = device

        # snapshotting-related argument registration
        self.result_path = result_path
        if "visual_infer_conf" in model_conf.keys():
            # configuration is given as a .yaml file
            if isinstance(model_conf["visual_infer_conf"], str):
                self.visual_infer_conf = load_yaml(
                    open(parse_path_args(model_conf["visual_infer_conf"]))
                )
            # configuration is explicitly given
            elif isinstance(model_conf["visual_infer_conf"], Dict):
                self.visual_infer_conf = model_conf["visual_infer_conf"]
            else:
                raise RuntimeError(
                    "model_conf['visual_infer_conf'] must be given as either a string or a Dict."
                )
        else:
            self.visual_infer_conf = dict()

        # --- 1. Model Construction --- #
        self.module_init(**module_conf, **model_conf["customize_conf"])
        self.criterion_init(**criterion_conf)
        # initialize the bad case selection methods by the hook function
        self.bad_cases_selection = self.bad_cases_selection_init_fn()

        # --- 2.1. Pretrained Model Loading --- #
        pretrained_model = (
            model_conf["pretrained_model"]
            if "pretrained_model" in model_conf.keys()
            else None
        )
        if pretrained_model is not None:
            pretrained_model = (
                pretrained_model
                if isinstance(pretrained_model, list)
                else [pretrained_model]
            )

            for ptm in pretrained_model:
                # argument checking
                if isinstance(ptm, str):
                    ptm = dict(path=parse_path_args(ptm))
                elif isinstance(ptm, Dict):
                    assert "path" in ptm.keys(), (
                        "If model['model_conf']['pretrained_model'] is given as a Dict, "
                        "please give a key named 'path' to specify where your pretrained model is placed."
                    )
                    if not os.path.exists(ptm["path"]):
                        raise RuntimeError(
                            f"The specified path of your pretrained model {ptm['path']} doesn't exist! "
                            f"Please check the input path."
                        )
                else:
                    raise TypeError(
                        f"The elements in model['model_conf']['pretrained_model'] must be either a string "
                        f"or a Dict, but got {ptm}"
                    )

                _pt_model = torch.load(
                    parse_path_args(ptm["path"]), map_location=self.device
                )
                mapping = ptm["mapping"] if "mapping" in ptm.keys() else None
                if mapping is None:
                    self.load_state_dict(
                        _pt_model,
                        strict=True if "strict" not in ptm.keys() else ptm["strict"],
                    )
                else:
                    assert isinstance(mapping, dict) and len(mapping) >= 1, (
                        f"mapping must be given as a dict and cannot be empty! "
                        f"Got type(mapping)={type(mapping)} and len(mapping)={len(mapping)}"
                    )

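                    # Illustrative example (hypothetical names): mapping={'encoder': 'asr_encoder'}
                    # renames the pretrained parameter 'encoder.linear.weight' to
                    # 'asr_encoder.linear.weight' in the loop below.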
                    _src_modules = OrderedDict()
                    # loop each name-parameter pair in the model
                    for name, para in _pt_model.items():
                        # loop each source-target mapping pair
                        for src, tgt in mapping.items():
                            # attach '.' to the end to make the name matching unambiguous
                            src, tgt = src + ".", tgt + "."
                            # change the parameter name in the middle
                            if src in name:
                                name = name.replace(src, tgt)
                        # record the parameter no matter whether its name is modified or not
                        _src_modules[name] = para
                    self.load_state_dict(
                        _src_modules,
                        strict=True if "strict" not in ptm.keys() else ptm["strict"],
                    )

        # --- 2.2. Model Parameter Initialization --- #
        else:
            # the default initialization method is xavier (i.e. xavier_normal)
            init = model_conf["init"] if "init" in model_conf.keys() else "xavier"
            assert (
                init in self.init_class_dict.keys()
            ), f"Only the initialization methods {self.init_class_dict.keys()} are supported, but got init={init}."

            for name, para in self.named_parameters():
                # initialize all the bias vectors to zero
                if ".bias" in name and para.dim() == 1:
                    torch.nn.init.zeros_(para)
                # initialize all the weight vectors except for those of normalization layers (BatchNorm & LayerNorm)
                elif para.dim() > 1:
                    self.init_class_dict[init](para)

            # initialize the modules that have their own default init methods
            for module in self.modules():
                if isinstance(module, tuple(self.default_init_modules)):
                    module.reset_parameters()

        # --- 3. Model Parameter Freezing --- #
        frozen_modules = (
            model_conf["frozen_modules"]
            if "frozen_modules" in model_conf.keys()
            else None
        )
        if frozen_modules is not None:
            if frozen_modules != "all":
                frozen_modules = (
                    frozen_modules
                    if isinstance(frozen_modules, list)
                    else [frozen_modules]
                )

            for name, para in self.named_parameters():
                if frozen_modules == "all":
                    para.requires_grad = False
                # freeze the parameter if its name matches any given module prefix
                elif any(name.startswith(module + ".") for module in frozen_modules):
                    para.requires_grad = False

            # every given module name must match at least one model parameter
            if frozen_modules != "all":
                for module in frozen_modules:
                    if not any(
                        name.startswith(module + ".")
                        for name, _ in self.named_parameters()
                    ):
                        raise RuntimeError(
                            f"frozen_modules: Parameters of {module} are not found in the model!"
                        )

    @abstractmethod
    def module_init(self, **kwargs) -> None:
        """The interface function that initializes the Module members of the model.
        These Module members make up the neural network structure of the model. Some
        models have their customized part that also needs to be initialized in this
        function, e.g. the tokenizer of ASR and TTS models.

        Note: This interface function must be overridden for each Model subclass.

        Args:
            **kwargs:
                The combination of the arguments in your given `module_conf` and `model_conf['customize_conf']`.
        """
        pass  # raise NotImplementedError

    @abstractmethod
    def criterion_init(self, **criterion_conf) -> None:
        """
        The interface function that initializes the Criterion members of the model. These Criterion members can be
        divided into two parts: the loss functions used for training and the evaluation metrics used for validation.

        Args:
            **criterion_conf:
                The arguments in your given `criterion_conf`.

        """
        pass  # raise NotImplementedError

    @staticmethod
    def bad_cases_selection_init_fn() -> List[List[str or int]] or None:
        """This hook function returns the default bad case selection method of each
        Model object. This default value will be referred to by the _Runner_ to present the
        top-N bad cases.

        The original hook implementation in the base Model class returns None which means no default value.

        Returns: List[List[str or int]]
            The returned default value should be a list of tri-list where each tri-list is in the form of
            [`selection_metric`, `selection_mode`, `case_number`]. For example, ['wer', 'max', 50] means 50 testing
            waveforms with the largest WER will be selected.
        """
        return None

    def batch_to_cuda(
        self, data: Dict[str, torch.Tensor] or torch.Tensor
    ) -> Dict[str, torch.Tensor] or torch.Tensor:
        """The recursive function that transfers the batch data to the specified device
        in the current process.

        Args:
            data: Dict or torch.Tensor
                The input batch data. It should be either a Tensor or a Dict of Tensors. For the Dict input, the
                function itself will be called once by each Tensor element.

        Returns: Dict or torch.Tensor
            If the input is a Dict, the returned output will also be a Dict of Tensors transferred to the target device;
            If the input is a Tensor, the returned output will be its copy on the target device.
        """
        # if the data is in the form of Dict, recursively process each key-value pair
        if isinstance(data, Dict):
            return {key: self.batch_to_cuda(value) for key, value in data.items()}
        # if the data is in the form of tensor, put it on GPUs by .cuda()
        elif isinstance(data, torch.Tensor):
            return data.cuda(device=self.device, non_blocking=self.non_blocking)
        # do nothing for other types of data
        else:
            return data

    def forward(self, batch_data: Dict, epoch: int = None, **kwargs):
        """
        The general model forward function shared by all the _Model_ subclasses. This forward function has 3 steps:
            1. preprocess and transfer the batch data to GPUs
            2. obtain the model prediction results
            3. calculate the loss function and evaluate the prediction results

        For each step above, we provide interface functions for you to override and make your own implementation.

        Args:
            batch_data: Dict
                The input batch data received from the `train` or `valid` dataloader object in the experimental
                pipeline.
            epoch: int = None
                The number of the current epoch. Used for real-time model visualization and model prediction.
            **kwargs:
                The additional arguments for real-time model visualization. If given, the code will go through the model
                visualization branch.

        Returns:
            In the training branch, the loss functions and evaluation metrics will be returned each of which is in the
            form of a Dict.
            In the validation branch, only the evaluation metrics will be returned.
            In the visualization branch, the model snapshots on the given validation instance will be returned.

        """
        # --- 1. Batch Data Preprocessing and GPU transferring --- #
        # --- data preparation below is shared by all the three branches: training, validation, and visualization --- #
        # preprocess the batch data if needed
        batch_data = self.batch_preprocess_fn(batch_data)

        # put the batch data onto GPUs
        batch_data = self.batch_to_cuda(batch_data)

        # --- 2.1. Model Visualization Branch --- #
        # if there are additional arguments other than batch_data and epoch, the visualization branch is activated
        if len(kwargs) != 0:
            return self.visualize(epoch=epoch, **batch_data, **kwargs)

        # --- 2.2. Model Forward Calculation --- #
        # --- model forward is shared by both the training and validation branches --- #
        # training runs the forward pass with autograd enabled; validation wraps it in torch.inference_mode()
        forward_context = nullcontext if self.training else torch.inference_mode
        with forward_context():
            try:
                # Feed the input batch into the model and get the outputs, copy.deepcopy() here is for the data safety
                model_outputs = self.module_forward(
                    epoch=epoch, **copy.deepcopy(batch_data)
                )
            except Exception as e:
                if not self.distributed:
                    raise e
                else:
                    skip_flag_list = torch.LongTensor(
                        [False for _ in range(torch.distributed.get_world_size())]
                    ).cuda(self.device)
                    skip_flag = torch.LongTensor([True]).cuda(self.device)
                    # as long as one node meets an error, all nodes will skip the current step at the same time
                    torch.distributed.all_gather_into_tensor(skip_flag_list, skip_flag)
                    if skip_flag_list.sum() >= 1:
                        raise e
            else:
                if self.distributed:
                    skip_flag_list = torch.LongTensor(
                        [False for _ in range(torch.distributed.get_world_size())]
                    ).cuda(self.device)
                    skip_flag = torch.LongTensor([False]).cuda(self.device)
                    # as long as one node meets an error, all nodes will skip the current step at the same time
                    torch.distributed.all_gather_into_tensor(skip_flag_list, skip_flag)
                    if skip_flag_list.sum() >= 1:
                        raise RuntimeError(
                            "Other ranks meet errors during model forwarding, "
                            "so this rank will also skip the current step!"
                        )

        # copy.deepcopy() cannot receive the non-leaf nodes in the computation graph (model_outputs). Since
        # model_outputs cannot be detached from the graph (gradients necessary), copy.deepcopy() is not used below.
        def combine_input_output(_batch_data: Dict, _model_outputs: Dict):
            combination, batch_keys = dict(), list(_batch_data.keys())
            # if the input batch data is in the form of Dict, it means there are multiple dataloaders
            if isinstance(_batch_data[batch_keys[0]], Dict):
                for key in batch_keys:
                    combination[key] = dict(**_batch_data[key], **_model_outputs[key])
            # if the input batch data is in the form of Tensor, it means there is only one dataloader.
            else:
                combination.update(_batch_data)
                combination.update(_model_outputs)
            return combination

        # --- 3.1. Model Training Branch --- #
        if self.training:
            # In the training stage, both the trainable losses and non-trainable metrics will be returned
            losses, metrics = self.criterion_forward(
                **combine_input_output(batch_data, model_outputs)
            )
            metrics.update(self.get_recordable_para())

            # post-checking for training losses, they must be trainable tensors
            assert sum(
                [
                    isinstance(loss, torch.Tensor) and loss.requires_grad
                    for loss in losses.values()
                ]
            ) == len(losses), "Training losses must be trainable tensors!"
            # post-checking for validation metrics, they must be either non-trainable tensors or other datatypes
            assert sum(
                [
                    not isinstance(metric, torch.Tensor) or not metric.requires_grad
                    for metric in metrics.values()
                ]
            ) == len(
                metrics
            ), "Validation metrics must be either non-trainable tensors or other datatypes!"

            # the non-trainable metrics will be averaged across all the processes in the distributed mode
            if self.distributed:
                metrics = self.aver_metrics_across_procs(metrics, batch_data)
            return losses, metrics

        # --- 3.2. Model Validation Branch --- #
        else:
            # In the validation stage, only the non-trainable metrics will be returned
            with torch.inference_mode():
                metrics = self.criterion_forward(
                    **combine_input_output(batch_data, model_outputs)
                )
            metrics.update(self.get_recordable_para())

            # post-checking for validation metrics, they must be either non-trainable tensors or other datatypes
            assert sum(
                [
                    not isinstance(metric, torch.Tensor) or not metric.requires_grad
                    for metric in metrics.values()
                ]
            ) == len(
                metrics
            ), "Validation metrics must be either non-trainable tensors or other datatypes!"

            # the non-trainable metrics will be averaged across all the processes in the distributed mode
            if self.distributed:
                metrics = self.aver_metrics_across_procs(metrics, batch_data)
            return metrics

    def batch_preprocess_fn(self, batch_data: Dict) -> Dict:
        """This hook function does the preprocessing for the input batch data before
        using them in self.model_forward(). This function is not mandatory to be
        overridden and the original implementation in the base Model class does the
        tensor transformation for the string-like data in batch_data (i.e., text and
        spk_ids).

        Note: the key names in the returned Dict should match the argument names in self.model_forward().

        Args:
            batch_data: Dict
                The raw data of the input batch to be preprocessed in this hook function.

        Returns: Dict
            The processed data of the input batch that is ready to be used in `self.module_forward()`.
        """

        def process_strings(data_dict: Dict):
            """Turn the text and speaker strings into tensors and get their lengths."""
            # --- Process the Text String and its Length --- #
            if "text" in data_dict.keys():
                if isinstance(data_dict["text"], List):
                    data_dict["text"], data_dict["text_len"] = text2tensor_and_len(
                        text_list=data_dict["text"],
                        text2tensor_func=self.tokenizer.text2tensor,
                        ignore_idx=self.tokenizer.ignore_idx,
                    )
                else:
                    assert isinstance(data_dict["text"], torch.Tensor)

            # --- Process the Speaker ID String --- #
            if "spk_ids" in data_dict.keys():
                if isinstance(data_dict["spk_ids"], List):
                    if hasattr(self, "spk2idx"):
                        data_dict["spk_ids"] = spk2tensor(
                            spk_list=data_dict["spk_ids"], spk2idx_dict=self.spk2idx
                        )
                elif not isinstance(data_dict["spk_ids"], torch.Tensor):
                    raise TypeError

            return data_dict

        # check whether the batch_data is made by multiple dataloaders
        leaf_flags = [not isinstance(value, Dict) for value in batch_data.values()]
        if sum(leaf_flags) == 0:
            return {key: process_strings(value) for key, value in batch_data.items()}
        elif sum(leaf_flags) == len(batch_data):
            return process_strings(batch_data)
        else:
            raise RuntimeError("Wrong composition of batch_data!")

    def aver_metrics_across_procs(
        self, metrics: Dict[str, torch.Tensor], batch_data: Dict
    ) -> Dict[str, torch.Tensor]:
        """This function averages the evaluation metrics across all GPU processes in the
        DDP mode for model distribution.

        Args:
            metrics: Dict[str, torch.Tensor]
                The evaluation metrics to be averaged across all GPU processes.
            batch_data: Dict
                The input batch data used to calculate the batch size for averaging evaluation metrics.

        Returns: Dict[str, torch.Tensor]
            The evaluation metrics _Dict_ after averaging. The key names remain the same.
        """

        def get_batch_size(input_dict: Dict):
            _batch_size = None
            for value in input_dict.values():
                # len() considers all types of array: torch.Tensor, np.ndarray, List, ...
                if _batch_size is None:
                    _batch_size = len(value)
                else:
                    assert _batch_size == len(value)
            return _batch_size

        # check the batch size
        multi_flag = sum(
            [isinstance(value, Dict) for value in batch_data.values()]
        ) == len(batch_data)
        # we take the summation of all data-labels pairs in a single batch made by multiple dataloaders
        if multi_flag:
            batch_size = sum([get_batch_size(value) for value in batch_data.values()])
        else:
            batch_size = get_batch_size(batch_data)
        batch_size = torch.tensor([batch_size], dtype=torch.long, device=self.device)

        # sum up all the weighted metrics at rank no.0
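        # i.e., overall_metric = sum_p(metric_p * batch_size_p) / sum_p(batch_size_p) across all processes p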
        for key in metrics.keys():
            # each metric should be one-dimensional scalar
            if metrics[key].dim() == 0:
                metrics[key] = metrics[key][None]
            elif metrics[key].dim() != 1:
                raise RuntimeError(
                    f"Each metric value must be one-dimensional scalar, "
                    f"but got metrics[{key}]={metrics[key]}!"
                )

            # batch_size acts as the weight for each metric value in the current process
            metrics[key] *= batch_size.type(metrics[key].dtype)
            # sum up the weighted metric values at rank no.0
            torch.distributed.reduce(
                metrics[key], dst=0, op=torch.distributed.ReduceOp.SUM
            )

        # sum up the batch sizes across all processes at rank no.0 to get the overall batch size
        torch.distributed.reduce(batch_size, dst=0, op=torch.distributed.ReduceOp.SUM)
        if torch.distributed.get_rank() == 0:
            for key in metrics.keys():
                # turn the object value to the overall batch-level
                metrics[key] /= batch_size.type(metrics[key].dtype)

        return metrics

    @abstractmethod
    def module_forward(self, epoch: int = None, **batch_data) -> Dict:
        """
        This function forwards the input batch data by all _Module_ members.
        Note:
            1. This interface function must be overridden for each Model subclass.
            2. The argument names should match the key names in the returned Dict of `self.batch_preprocess_fn()`.
            3. The key names in the returned Dict should match the argument names of `self.criterion_forward()`.

        Args:
            epoch:
            **batch_data:
                Processed data of the input batch received from `self.batch_preprocess_fn()`.

        Returns: Dict
            Prediction results (logits) of the model on the input batch data.
            Some intermediate results (e.g., attention matrices) can also be returned for later use.

        """
        pass  # raise NotImplementedError

    @abstractmethod
    def criterion_forward(
        self, **kwargs
    ) -> (Dict[str, torch.Tensor], Dict[str, torch.Tensor]) or Dict[str, torch.Tensor]:
        """This interface function is activated after `self.model_forward()`. It
        receives the model prediction results from `self.model_forward()` and input
        batch data from `self.batch_preprocess_fn()`.

        Args:
            **kwargs:
                The combination of the returned arguments from `self.batch_preprocess_fn()` and `self.module_forward()`.

        Returns: (Dict[str, torch.Tensor], Dict[str, torch.Tensor]) or Dict[str, torch.Tensor]
            The returned values should be different for the training and validation branches.
            1. For training, two Dict[str, torch.Tensor] should be returned where the first one contains all the
            trainable training losses for optimization and the second one contains all the non-trainable evaluation
            metrics used to record the training status.
            2. For validation, only one Dict[str, torch.Tensor] should be returned which contains all the non-trainable
            evaluation metrics used to record the validation status.
        """
        pass  # raise NotImplementedError

    def get_recordable_para(self) -> Dict[str, torch.Tensor]:
        """Recursively retrieves the recordable parameters from the module's sub-
        modules.

        Returns:
            Dict[str, torch.Tensor]: A dictionary mapping the parameter names to their corresponding tensor values.
        """

        def recur_get_module_recordable_para(curr_node, prefix_list: List[str] = None):
            if prefix_list is None:
                prefix_list = []
            if isinstance(curr_node, Dict):
                _output = dict()
                for _key, _value in curr_node.items():
                    _output.update(
                        recur_get_module_recordable_para(_value, prefix_list + [_key])
                    )
                return _output
            else:
                if curr_node is None:
                    return {}
                elif isinstance(curr_node, torch.Tensor):
                    return {"_".join(prefix_list): curr_node.clone().detach()}
                else:
                    raise RuntimeError

        output = dict()
        for key, value in self._modules.items():
            if isinstance(value, Module):
                output.update(
                    recur_get_module_recordable_para(value.get_recordable_para(), [key])
                )
        return output

    def matrix_snapshot(
        self,
        vis_logs: List,
        hypo_attention: Dict,
        subfolder_names: List[str] or str,
        epoch: int,
    ):
        """Used by the abstract function visualize() to make the snapshot materials for
        attention matrices."""
        if isinstance(subfolder_names, str):
            subfolder_names = [subfolder_names]
        keys = list(hypo_attention.keys())

        # process the input data by different data types
        if isinstance(hypo_attention[keys[0]], Dict):
            for key, value in hypo_attention.items():
                self.matrix_snapshot(
                    vis_logs=vis_logs,
                    hypo_attention=value,
                    subfolder_names=subfolder_names + [key],
                    epoch=epoch,
                )

        # snapshot the information in the materials
        elif isinstance(hypo_attention[keys[0]], np.ndarray):
            vis_logs.append(
                dict(
                    plot_type="matrix",
                    materials=hypo_attention,
                    epoch=epoch,
                    sep_save=False,
                    data_save=True,
                    subfolder_names=subfolder_names,
                )
            )

    def attention_reshape(self, hypo_attention: Dict, prefix_list: List = None) -> Dict:
        """Used by the abstract function visualize() to reshape the attention matrices
        before matrix_snapshot()."""
        if prefix_list is None:
            prefix_list = []

        # process the input data by different data types
        if isinstance(hypo_attention, Dict):
            return {
                key: self.attention_reshape(value, prefix_list + [key])
                for key, value in hypo_attention.items()
            }
        elif isinstance(hypo_attention, List):
            return {
                str(index - len(hypo_attention)): self.attention_reshape(
                    hypo_attention[index],
                    prefix_list + [str(index - len(hypo_attention))],
                )
                for index in range(len(hypo_attention) - 1, -1, -1)
            }
        elif isinstance(hypo_attention, torch.Tensor):
            hypo_attention = hypo_attention.squeeze()
            if hypo_attention.is_cuda:
                hypo_attention = hypo_attention.detach().cpu()

            if hypo_attention.dim() == 2:
                return {".".join(prefix_list + [str(0)]): hypo_attention.numpy()}
            elif hypo_attention.dim() == 3:
                return {
                    ".".join(prefix_list + [str(index)]): element.numpy()
                    for index, element in enumerate(hypo_attention)
                }
            else:
                raise RuntimeError

    @abstractmethod
    def visualize(self, epoch: int, sample_index: str, **valid_sample):
        """

        Args:
            epoch:
            sample_index:
            **valid_sample:

        Returns:

        """
        pass  # raise NotImplementedError

    def evaluate(self, test_batch: Dict, infer_conf: Dict):
        """
        The shared evaluation function by all _Model_ subclasses. This evaluation function has 2 steps:
            1. preprocess and transfer the batch data to GPUs
            2. calculate the inference results

        For each step above, we provide interface functions for you to override and make your own implementation.

        Args:
            test_batch: Dict
                The input batch data received from the `test` dataloader object in the experimental pipeline.
            infer_conf: Dict
                The configuration used for model inference.

        Returns:
            A Dict of the inference results where each key-value item corresponds to one evaluation metric you want to
            save to the disk.

        """
        # preprocess the batch data if needed
        test_batch = self.batch_preprocess_fn(test_batch)

        # put the batch data onto GPUs
        test_batch = self.batch_to_cuda(test_batch)

        # get the inference results
        evaluate_results = self.inference(infer_conf=infer_conf, **test_batch)
        if (
            hasattr(self, "instance_report_cache")
            and self.instance_report_cache is not None
        ):
            evaluate_results["instance_reports.md"] = self.instance_report_cache
            self.instance_report_cache = None

        # post-check the format of evaluate_results
        if isinstance(evaluate_results, Dict):
            for key, value in evaluate_results.items():
                if "format" not in value.keys() or "content" not in value.keys():
                    raise RuntimeError(
                        "Each element of the returned value of self.inference() must contain the keys "
                        "named both 'format' and 'content'!"
                    )
        else:
            raise RuntimeError(
                f"The returned value of self.inference() must be a Dict, "
                f"but got {type(evaluate_results)}!"
            )
        return evaluate_results

    @abstractmethod
    def inference(
        self, infer_conf: Dict, **kwargs
    ) -> Dict[str, Dict[str, str or List]]:
        """This function receives the test data and test configuration. The inference
        results will be packaged into a Dict[str, Dict] which is passed to TestMonitor
        for disk storage. The returned Dict should be in the form of ``` dict(
        {file_name}=dict( format={file_format},

                content={file_content}
            )
        )
        ```
        The first-level key is used to decide the name of the meta file as `idx2{file_name}`. Its value is also a Dict
        and there must be two keys in this sub-Dict: 'format' and 'content'. The configuration of the sub-Dict is
        different for different file formats:

            1. For pure text metadata files, the value of 'format' must be 'txt' and the value of 'content' must be a
            list of Python built-in data types (i.e., int, float, str, bool, ...).
            Each line of the file `idx2{file_name}` will be made up of the index of a test data instance and its
            metadata value in the `content` List which are separated by a blank.
            For example,
            `dict(cer=dict(format='txt', content=[0.1, 0.2, 0.3]))` will create a pure text file named 'idx2cer' which
            looks like
            ```
            {test_index1} 0.1
            {test_index2} 0.2
            {test_index3} 0.3
            ```
            Note: if the first-level key ends with '.md', there will not be 'idx2' attached at the beginning of the
            file name.

            2. For audio files, the value of 'format' must be either 'wav' or 'flac' and the value of 'content' must be
            a list of array-like data types (e.g., numpy.ndarray, torch.Tensor, ...).
            Moreover, there must be an additional key named 'sample_rate' to indicate the sampling rate of the waveforms
            to be saved in audio files.
            There will be a folder named `{file_name}` that contains all the audio files and a pure text file named
            `idx2{file_name}` that contains the absolute paths of all the saved audio files.
            For example,
            `dict(wav=dict(format='flac', content=[np_arr1, np_arr2, np_arr3]))` will create a folder named 'wav' and
            a pure text file named 'idx2wav' in the same directory. The file 'idx2wav' looks like:
            ```
            {test_index1} /x/xx/wav/{test_index1}.flac
            {test_index2} /x/xx/wav/{test_index2}.flac
            {test_index3} /x/xx/wav/{test_index3}.flac
            ```
            where `/x/xx/` is your result path given in your `exp_cfg`.

            3. For binary files, the value of 'format' in the sub-Dict must be 'npy' and the value of 'content' must be
            a list of numpy.ndarray (torch.Tensor is not supported).
            There will be a folder named `{file_name}` that contains all the .npy files and a pure text file
            named `idx2{file_name}` that contains the absolute paths of all the saved binary files.
            For example,
            `dict(feat=dict(format='npy', content=[np_arr1, np_arr2, np_arr3]))`
            will create a folder named 'feat' and a pure text file named 'idx2feat'. The 'idx2feat' file is like:
            ```
            {test_index1} /x/xx/feat/{test_index1}.npy
            {test_index2} /x/xx/feat/{test_index2}.npy
            {test_index3} /x/xx/feat/{test_index3}.npy
            ```
            where `/x/xx/` is your result path given in your `exp_cfg`.
        """
        pass  # raise NotImplementedError

    def register_instance_reports(
        self, md_list_dict: Dict[str, List], extra_string_list: List[str] = None
    ):
        """

        Args:
            md_list_dict:
            extra_string_list:

        Returns:

        """
        # --- 1. Arguments Checking --- #
        if extra_string_list is not None:
            assert isinstance(extra_string_list, List)

        ele_len = []
        for value in md_list_dict.values():
            assert isinstance(value, List)
            if extra_string_list is not None:
                assert len(value) == len(extra_string_list)
            ele_len.append(len(value))

        if len(set(ele_len)) == 1:
            ele_len = ele_len[0]
        else:
            raise RuntimeError

        # --- 2. Generate .md Instance Report for the current step --- #
        instance_reports = []
        for i in range(ele_len):
            ele_dict = {
                key: value[i] if isinstance(value[i], str) else str(value[i])
                for key, value in md_list_dict.items()
            }
            _curr_report = "\n\n" + get_list_strings(ele_dict) + "\n"

            if extra_string_list is not None:
                _curr_report += extra_string_list[i] + "\n"
            instance_reports.append(_curr_report)

        self.instance_report_cache = dict(format="txt", content=instance_reports)
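
Since the core interfaces above are abstract, a concrete model must override them. The following skeleton (hypothetical names and placeholder bodies, not a class from the toolkit) shows the minimal surface a subclass implements:

```python
import torch
from typing import Dict, List

# Hypothetical subclass skeleton; `Model` is the base class documented above.
class ToyModel(Model):
    def module_init(self, hidden_size: int = 8, **kwargs) -> None:
        # Module members that make up the network structure
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

    def criterion_init(self, **criterion_conf) -> None:
        # loss functions for training and metrics for validation
        self.loss_fn = torch.nn.MSELoss()

    def module_forward(self, epoch: int = None, **batch_data) -> Dict:
        # prediction results keyed for criterion_forward()
        return dict(pred=self.linear(batch_data["feat"]))

    def criterion_forward(self, feat, pred, **kwargs):
        loss = self.loss_fn(pred, feat)
        metrics = dict(loss=loss.clone().detach())
        # training branch returns (losses, metrics); validation returns metrics only
        return (dict(loss=loss), metrics) if self.training else metrics

    def visualize(self, epoch: int, sample_index: str, **valid_sample):
        return []  # no snapshot materials in this toy example

    def inference(self, infer_conf: Dict, **kwargs) -> Dict[str, Dict[str, str or List]]:
        # each item must contain 'format' and 'content' (checked by evaluate())
        return dict(pred=dict(format="txt", content=["placeholder"]))
```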

__init__(device, module_conf, result_path, model_conf=None, criterion_conf=None, non_blocking=False, distributed=False)

In this initialization function, there are two parts of initialization: model-specific customized initialization and model-independent general initialization.

Model-specific customized initialization is done by two interface functions: module_init() and criterion_init(). module_init() initializes the neural network structure of the model while criterion_init() initializes the criteria used to optimize (loss functions) and evaluate (validation metrics) the model.

After the customized initialization, there are 3 steps for general initialization (a configuration sketch follows this list):

1. Pretrained parameters will be loaded into your model if the key `pretrained_model` is given. Multiple
pretrained models can be specified and each of them can be loaded into a different part of your model. The
mismatch between the names of pretrained parameters and the parameters of your model is handled by the key
`mapping`. The value of the key `mapping` is a dictionary where each key-value item corresponds to a mapping
of parameter names. The key is the parameter name in the pretrained parameters while the value is the
parameter name of your model.

2. If `pretrained_model` is not given, the parameters of your model will be initialized by the function that
matches your input query 'init'. Please refer to the built-in dictionary `init_class_dict` for the available
initialization functions. If `init` is not given, the default initialization function
`torch.nn.init.xavier_normal_` will be used to initialize your model.

3. Finally, the specified parts of your model will be frozen if 'frozen_modules' is given. If there is only
one frozen module, you can directly give the string of its name to 'frozen_modules' like
'frozen_modules: {module_name}'; if there are multiple modules you want to freeze, you can give their names
in a list as
```
frozen_modules:
  - {module_name1}
  - {module_name2}
  - ...
```
Moreover, the frozen granularity depends on your input `frozen_modules`.
For example,
    1. If you give 'frozen_modules: encoder_prenet', all parameters of the prenet of your encoder will be
    frozen
    2. If you give 'frozen_modules: encoder_prenet.conv', only the convolution layers of the prenet of your
    encoder will be frozen
    3. If you give 'frozen_modules: encoder_prenet.conv.0', only the first convolution layer of the prenet
    of your encoder will be frozen
    4. If you give 'frozen_modules: encoder_prenet.conv.0.bias', only the bias vector of the first
    convolution layer of the prenet of your encoder will be frozen
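
Putting the three steps together, a `model_conf` might look like the sketch below (all paths and module names are hypothetical):

```python
# hypothetical model_conf covering pretrained loading, initialization, and freezing
model_conf = dict(
    pretrained_model=dict(
        path="exp/asr/models/10_best.pth",    # hypothetical checkpoint path
        mapping={"encoder": "asr_encoder"},   # pretrained name -> your model's name
        strict=False,
    ),
    # 'init' only takes effect when 'pretrained_model' is absent
    init="xavier_uniform",
    frozen_modules=["encoder_prenet.conv"],
)
```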

Parameters:

- device (torch.device): The computational device used for model calculation in the current GPU process. Required.
- module_conf (Dict): The module configuration used for network structure initialization. Required.
- result_path (str): The path where snapshot-related result files are saved. Required.
- model_conf (Dict): The model configuration used for general model initialization. Default: None.
- criterion_conf (Dict): The criterion configuration used for criterion (loss functions and evaluation metrics) initialization. Default: None.
- non_blocking (bool): Whether to transfer the batch data to GPUs in a non-blocking way. Default: False.
- distributed (bool): Whether the model runs in the distributed (DDP) mode. Default: False.
Source code in speechain/model/abs.py
def __init__(
    self,
    device: torch.device,
    module_conf: Dict,
    result_path: str,
    model_conf: Dict = None,
    criterion_conf: Dict = None,
    non_blocking: bool = False,
    distributed: bool = False,
):
    """In this initialization function, there are two parts of
    initialization: model-specific customized initialization and model-
    independent general initialization.

    Model-specific customized initialization is done by two interface functions: module_init() and criterion_init().
    module_init() initializes the neural network structure of the model while criterion_init() initializes the
    criteria used to optimize (loss functions) and evaluate (validation metrics) the model.

    After the customized initialization, there are 3 steps for general initialization:
        1. Pretrained parameters will be loaded into your model if the key `pretrained_model` is given. Multiple
        pretrained models can be specified and each of them can be loaded into different parts of your model. The
        mismatch between the names of pretrained parameters and the parameters of your model is handled by the key
        'mapping'. The value of the key `mapping` is a dictionary where each key-value item corresponds to a mapping
        of parameter names. The key is the parameter name in the pretrained parameters while the value is the
        parameter name of your model.

        2. If `pretrained_model` is not given, the parameters of your model will be initialized by the function that
        matches your input query 'init'. Please refer to the built-in dictionary `init_class_dict` for the available
        initialization functions. If `init` is not given, the default initialization function
        `torch.nn.init.xavier_normal_` will be used to initialize your model.

        3. Finally, the specified parts of your model will be frozen if 'frozen_modules' is given. If there is only
        one frozen module, you can directly give the string of its name to 'frozen_modules' like
        'frozen_modules: {module_name}'; if there are multiple modules you want to freeze, you can give their names
        in a list as
        ```
        frozen_modules:
          - {module_name1}
          - {module_name2}
          - ...
        ```
        Moreover, the frozen granularity depends on your input `frozen_modules`.
        For example,
            1. If you give 'frozen_modules: encoder_prenet', all parameters of the prenet of your encoder will be
            frozen
            2. If you give 'frozen_modules: encoder_prenet.conv', only the convolution layers of the prenet of your
            encoder will be frozen
            3. If you give 'frozen_modules: encoder_prenet.conv.0', only the first convolution layer of the prenet
            of your encoder will be frozen
            4. If you give 'frozen_modules: encoder_prenet.conv.0.bias', only the bias vector of the first
            convolution layer of the prenet of your encoder will be frozen

    Args:
        device (torch.device):
            The computational device used for model calculation in the current GPU process.
        model_conf (Dict):
            The model configuration used for general model initialization.
        module_conf (Dict):
            The module configuration used for network structure initialization.
        criterion_conf (Dict):
            The criterion configuration used for criterion (loss functions and evaluation metrics) initialization.
    """
    super(Model, self).__init__()

    # input argument checking
    assert module_conf is not None, "module_conf cannot be None!"
    # model_conf is default to be an empty dictionary
    model_conf = dict() if model_conf is None else model_conf
    # criterion_conf is default to be an empty dictionary
    criterion_conf = dict() if criterion_conf is None else criterion_conf
    # customize_conf is default to be an empty dictionary
    if "customize_conf" not in model_conf.keys():
        model_conf["customize_conf"] = dict()

    # general argument registration
    self.non_blocking = non_blocking
    self.distributed = distributed
    self.device = device

    # snapshotting-related argument registration
    self.result_path = result_path
    if "visual_infer_conf" in model_conf.keys():
        # configuration is given as a .yaml file
        if isinstance(model_conf["visual_infer_conf"], str):
            self.visual_infer_conf = load_yaml(
                open(parse_path_args(model_conf["visual_infer_conf"]))
            )
        # configuration is explicitly given
        elif isinstance(model_conf["visual_infer_conf"], Dict):
            self.visual_infer_conf = model_conf["visual_infer_conf"]
        else:
            raise RuntimeError(
                "model_conf['visual_infer_conf'] must be given as either a string or a Dict."
            )
    else:
        self.visual_infer_conf = dict()

    # --- 1. Model Construction --- #
    self.module_init(**module_conf, **model_conf["customize_conf"])
    self.criterion_init(**criterion_conf)
    # initialize the bad case selection methods by the hook function
    self.bad_cases_selection = self.bad_cases_selection_init_fn()

    # --- 2.1. Pretrained Model Loading --- #
    pretrained_model = (
        model_conf["pretrained_model"]
        if "pretrained_model" in model_conf.keys()
        else None
    )
    if pretrained_model is not None:
        pretrained_model = (
            pretrained_model
            if isinstance(pretrained_model, list)
            else [pretrained_model]
        )

        for ptm in pretrained_model:
            # argument checking
            if isinstance(ptm, str):
                ptm = dict(path=parse_path_args(ptm))
            elif isinstance(ptm, Dict):
                assert "path" in ptm.keys(), (
                    "If model['model_conf']['pretrained_model'] is given as a Dict, "
                    "please give a key named 'path' to specify where your pretrained model is placed."
                )
                if os.path.exists(ptm["path"]):
                    raise RuntimeError(
                        f"The specified path of your pretrained model {ptm['path']} doesn't exist! "
                        f"Please check the input path."
                    )
            else:
                raise TypeError(
                    f"The elements in model['model_conf']['pretrained_model'] must be either a string "
                    f"or a Dict, but got {ptm}"
                )

            _pt_model = torch.load(
                parse_path_args(ptm["path"]), map_location=self.device
            )
            mapping = ptm["mapping"] if "mapping" in ptm.keys() else None
            if mapping is None:
                self.load_state_dict(
                    _pt_model,
                    strict=True if "strict" not in ptm.keys() else ptm["strict"],
                )
            else:
                assert isinstance(mapping, dict) and len(mapping) >= 1, (
                    f"mapping must be given as a dict and cannot be empty! "
                    f"Got type(mapping)={type(mapping)} and len(mapping)={len(mapping)}"
                )

                _src_modules = OrderedDict()
                # loop each name-parameter pair in the model
                for name, para in _pt_model.items():
                    # loop each source-target mapping pair
                    for src, tgt in mapping.items():
                        # attaching '.' to the end makes the module prefix unambiguous
                        src, tgt = src + ".", tgt + "."
                        # replace the source prefix inside the parameter name
                        if src in name:
                            name = name.replace(src, tgt)
                    # record the parameter whether or not its name was modified
                    _src_modules[name] = para
                self.load_state_dict(
                    _src_modules,
                    strict=True if "strict" not in ptm.keys() else ptm["strict"],
                )

    # --- 2.2. Model Parameter Initialization --- #
    else:
        # the default initialization method is xavier (i.e. xavier_normal)
        init = model_conf["init"] if "init" in model_conf.keys() else "xavier"
        assert (
            init in self.init_class_dict.keys()
        ), f"Only the initialization methods {self.init_class_dict.keys()} are supported, but got init={init}."

        for name, para in self.named_parameters():
            # initialize all the bias vectors to zero
            if ".bias" in name and para.dim() == 1:
                torch.nn.init.zeros_(para)
            # initialize all the weight matrices, except those of the normalization layers (BatchNorm & LayerNorm)
            elif para.dim() > 1:
                self.init_class_dict[init](para)

        # initialize the modules that have their own default init methods
        for module in self.modules():
            if isinstance(module, tuple(self.default_init_modules)):
                module.reset_parameters()

    # --- 3. Model Parameter Freezing --- #
    frozen_modules = (
        model_conf["frozen_modules"]
        if "frozen_modules" in model_conf.keys()
        else None
    )
    if frozen_modules is not None:
        if frozen_modules != "all":
            frozen_modules = (
                frozen_modules
                if isinstance(frozen_modules, list)
                else [frozen_modules]
            )

        # freeze the parameters that belong to the given frozen modules
        matched_modules = set()
        for name, para in self.named_parameters():
            if frozen_modules == "all":
                para.requires_grad = False
            else:
                for module in frozen_modules:
                    # a parameter belongs to a module if its name starts with the module prefix
                    if name.startswith(module + "."):
                        matched_modules.add(module)
                        para.requires_grad = False
                        break

        # report the frozen modules whose parameters are not found in the model
        if frozen_modules != "all":
            unmatched_modules = set(frozen_modules) - matched_modules
            if len(unmatched_modules) > 0:
                raise RuntimeError(
                    f"frozen_modules: Parameters of {unmatched_modules} are not found in the model!"
                )
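
To make the construction flow above concrete, here is a hypothetical `model_conf` covering both the pretrained-model loading of step 2.1 (with a `mapping` that renames parameter prefixes) and the module freezing of step 3; the path and module names are placeholders:

```
# a hypothetical model_conf; the path and module names are placeholders
model_conf = dict(
    # load an encoder pretrained elsewhere, renaming its parameter prefix
    pretrained_model=dict(
        path="recipes/asr/exp/models/pretrained.pth",
        mapping={"encoder_prenet": "encoder.prenet"},
        strict=False,
    ),
    # freeze the loaded encoder so only the rest of the model is trained
    frozen_modules=["encoder"],
)
```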

attention_reshape(hypo_attention, prefix_list=None)

Used by the abstract function visualize() to reshape the attention matrices before matrix_snapshot().

Source code in speechain/model/abs.py
def attention_reshape(self, hypo_attention: Dict, prefix_list: List = None) -> Dict:
    """Used by the abstract function visualize() to reshape the attention matrices
    before matrix_snapshot()."""
    if prefix_list is None:
        prefix_list = []

    # process the input data by different data types
    if isinstance(hypo_attention, Dict):
        return {
            key: self.attention_reshape(value, prefix_list + [key])
            for key, value in hypo_attention.items()
        }
    elif isinstance(hypo_attention, List):
        return {
            str(index - len(hypo_attention)): self.attention_reshape(
                hypo_attention[index],
                prefix_list + [str(index - len(hypo_attention))],
            )
            for index in range(len(hypo_attention) - 1, -1, -1)
        }
    elif isinstance(hypo_attention, torch.Tensor):
        hypo_attention = hypo_attention.squeeze()
        if hypo_attention.is_cuda:
            hypo_attention = hypo_attention.detach().cpu()

        if hypo_attention.dim() == 2:
            return {".".join(prefix_list + [str(0)]): hypo_attention.numpy()}
        elif hypo_attention.dim() == 3:
            return {
                ".".join(prefix_list + [str(index)]): element.numpy()
                for index, element in enumerate(hypo_attention)
            }
        else:
            raise RuntimeError(
                f"The attention tensor must be 2-dim or 3-dim after squeezing, "
                f"but got {hypo_attention.dim()}-dim!"
            )
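
As a quick illustration (the key names and shapes are made up), a 3-dim attention tensor is split into one 2-dim numpy matrix per head, with the dotted name recording where each matrix came from:

```
import torch

# a hypothetical input: one decoder layer holding a (4 heads, 6, 5) attention tensor
hypo_attention = {"dec": [torch.rand(4, 6, 5)]}

# model.attention_reshape(hypo_attention) keeps the nesting but turns the leaf
# tensor into named 2-dim numpy matrices, one per attention head:
# {"dec": {"-1": {"dec.-1.0": ndarray(6, 5), ..., "dec.-1.3": ndarray(6, 5)}}}
```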

aver_metrics_across_procs(metrics, batch_data)

This function averages the evaluation metrics across all GPU processes in the DDP mode for model distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics | Dict[str, torch.Tensor] | The evaluation metrics to be averaged across all GPU processes. | required |
| batch_data | Dict | The input batch data used to calculate the batch size for averaging evaluation metrics. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, torch.Tensor] | The evaluation metrics Dict after averaging. The key names remain the same. |

Source code in speechain/model/abs.py
def aver_metrics_across_procs(
    self, metrics: Dict[str, torch.Tensor], batch_data: Dict
) -> Dict[str, torch.Tensor]:
    """This function averages the evaluation metrics across all GPU processes in the
    DDP mode for model distribution.

    Args:
        metrics: Dict[str, torch.Tensor]
            The evaluation metrics to be averaged across all GPU processes.
        batch_data: Dict
            The input batch data used to calculate the batch size for averaging evaluation metrics.

    Returns: Dict[str, torch.Tensor]
        The evaluation metrics _Dict_ after averaging. The key names remain the same.
    """

    def get_batch_size(input_dict: Dict):
        _batch_size = None
        for value in input_dict.values():
            # len() considers all types of array: torch.Tensor, np.ndarray, List, ...
            if _batch_size is None:
                _batch_size = len(value)
            else:
                assert _batch_size == len(value), "Inconsistent batch sizes in the input Dict!"
        return _batch_size

    # check the batch size
    multi_flag = sum(
        [isinstance(value, Dict) for value in batch_data.values()]
    ) == len(batch_data)
    # we take the summation of all data-labels pairs in a single batch made by multiple dataloaders
    if multi_flag:
        batch_size = sum([get_batch_size(value) for value in batch_data.values()])
    else:
        batch_size = get_batch_size(batch_data)
    batch_size = torch.tensor([batch_size], dtype=torch.long, device=self.device)

    # sum up all the weighted metrics at rank no.0
    for key in metrics.keys():
        # each metric value must be a scalar: expand 0-dim tensors to 1-dim for reduce()
        if metrics[key].dim() == 0:
            metrics[key] = metrics[key][None]
        elif metrics[key].dim() != 1:
            raise RuntimeError(
                f"Each metric value must be a scalar (0-dim or 1-dim tensor), "
                f"but got metrics[{key}]={metrics[key]}!"
            )

        # batch_size acts as the weight for each metric value in the current process
        metrics[key] *= batch_size.type(metrics[key].dtype)
        # sum up the weighted metric values at rank no.0
        torch.distributed.reduce(
            metrics[key], dst=0, op=torch.distributed.ReduceOp.SUM
        )

    # sum up the batch sizes of all processes at rank no.0 to get the overall batch size
    torch.distributed.reduce(batch_size, dst=0, op=torch.distributed.ReduceOp.SUM)
    if torch.distributed.get_rank() == 0:
        for key in metrics.keys():
            # normalize the summed metric value by the overall batch size
            metrics[key] /= batch_size.type(metrics[key].dtype)

    return metrics
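
A worked example of the batch-size weighting with made-up numbers: the value recorded at rank no.0 is the instance-level average, not the mean of the per-process means:

```
# weighted average across two processes (hypothetical numbers)
loss_rank0, batch_rank0 = 0.5, 2
loss_rank1, batch_rank1 = 0.2, 4

aver = (loss_rank0 * batch_rank0 + loss_rank1 * batch_rank1) / (batch_rank0 + batch_rank1)
print(aver)  # ~0.3, while the unweighted mean of 0.5 and 0.2 would be 0.35
```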

bad_cases_selection_init_fn() staticmethod

This hook function returns the default bad case selection method of each Model object. This default value will be referred to by the Runner to present the top-N bad cases.

The original hook implementation in the base Model class returns None which means no default value.

Returns:

| Type | Description |
| --- | --- |
| List[List[str or int]] or None | The returned default value should be a list of tri-lists where each tri-list is in the form of [selection_metric, selection_mode, case_number]. For example, ['wer', 'max', 50] means the 50 testing waveforms with the largest WER will be selected. |

Source code in speechain/model/abs.py
@staticmethod
def bad_cases_selection_init_fn() -> List[List[str or int]] or None:
    """This hook function returns the default bad case selection method of each
    Model object. This default value will be referred to by the _Runner_ to present the
    top-N bad cases.

    The original hook implementation in the base Model class returns None which means no default value.

    Returns: List[List[str or int]]
        The returned default value should be a list of tri-lists where each tri-list is in the form of
        [`selection_metric`, `selection_mode`, `case_number`]. For example, ['wer', 'max', 50] means 50 testing
        waveforms with the largest WER will be selected.
    """
    return None
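
A subclass override might look like the following sketch (the metric name is hypothetical):

```
@staticmethod
def bad_cases_selection_init_fn() -> List[List[str or int]] or None:
    # present the 50 utterances with the largest WER and the 10 with the smallest
    return [["wer", "max", 50], ["wer", "min", 10]]
```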

batch_preprocess_fn(batch_data)

This hook function preprocesses the input batch data before it is used in self.module_forward(). Overriding this function is not mandatory; the default implementation in the base Model class performs tensor transformation for the string-like data in batch_data (i.e., text and spk_ids).

Note: the key names in the returned Dict should match the argument names of self.module_forward().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_data | Dict | The raw data of the input batch to be preprocessed in this hook function. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict | The processed data of the input batch, ready to be used in self.module_forward(). |

Source code in speechain/model/abs.py
def batch_preprocess_fn(self, batch_data: Dict) -> Dict:
    """This hook function does the preprocessing for the input batch data before
    using them in self.model_forward(). This function is not mandatory to be
    overridden and the original implementation in the base Model class does the
    tensor transformation for the string-like data in batch_data (i.e., text and
    spk_ids).

    Note: the key names in the returned Dict should match the argument names in self.model_forward().

    Args:
        batch_data: Dict
            The raw data of the input batch to be preprocessed in this hook function.

    Returns: Dict
            The processed data of the input batch that is ready to be used in `self.module_forward()`.
    """

    def process_strings(data_dict: Dict):
        """Turn the text and speaker strings into tensors and get their lengths."""
        # --- Process the Text String and its Length --- #
        if "text" in data_dict.keys():
            if isinstance(data_dict["text"], List):
                data_dict["text"], data_dict["text_len"] = text2tensor_and_len(
                    text_list=data_dict["text"],
                    text2tensor_func=self.tokenizer.text2tensor,
                    ignore_idx=self.tokenizer.ignore_idx,
                )
            else:
                assert isinstance(data_dict["text"], torch.Tensor)

        # --- Process the Speaker ID String --- #
        if "spk_ids" in data_dict.keys():
            if isinstance(data_dict["spk_ids"], List):
                if hasattr(self, "spk2idx"):
                    data_dict["spk_ids"] = spk2tensor(
                        spk_list=data_dict["spk_ids"], spk2idx_dict=self.spk2idx
                    )
            elif not isinstance(data_dict["spk_ids"], torch.Tensor):
                raise TypeError

        return data_dict

    # check whether the batch_data is made by multiple dataloaders
    leaf_flags = [not isinstance(value, Dict) for value in batch_data.values()]
    if sum(leaf_flags) == 0:
        return {key: process_strings(value) for key, value in batch_data.items()}
    elif sum(leaf_flags) == len(batch_data):
        return process_strings(batch_data)
    else:
        raise RuntimeError("Wrong composition of batch_data!")

batch_to_cuda(data)

The recursive function that transfers the batch data to the specified device in the current process.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | Dict[str, Tensor] or Tensor | The input batch data. It should be either a Tensor or a Dict of Tensors. For a Dict input, the function recursively calls itself once for each element. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Tensor] or Tensor | If the input is a Dict, the returned output will also be a Dict of Tensors transferred to the target device; if the input is a Tensor, the returned output will be its copy on the target device. |

Source code in speechain/model/abs.py
def batch_to_cuda(
    self, data: Dict[str, torch.Tensor] or torch.Tensor
) -> Dict[str, torch.Tensor] or torch.Tensor:
    """The recursive function that transfers the batch data to the specified device
    in the current process.

    Args:
        data: Dict or torch.Tensor
            The input batch data. It should be either a Tensor or a Dict of Tensors. For a Dict input, the
            function recursively calls itself once for each element.

    Returns: Dict or torch.Tensor
        If the input is a Dict, the returned output will also be a Dict of Tensors transferred to the target device;
        If the input is a Tensor, the returned output will be its copy on the target device.
    """
    # if the data is in the form of Dict, recursively process each key-value pair
    if isinstance(data, Dict):
        return {key: self.batch_to_cuda(value) for key, value in data.items()}
    # if the data is in the form of tensor, put it on GPUs by .cuda()
    elif isinstance(data, torch.Tensor):
        return data.cuda(device=self.device, non_blocking=self.non_blocking)
    # do nothing for other types of data
    else:
        return data

criterion_forward(**kwargs) abstractmethod

This interface function is activated after self.module_forward(). It receives the model prediction results from self.module_forward() and the input batch data from self.batch_preprocess_fn().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| **kwargs |  | The combination of the returned arguments from self.batch_preprocess_fn() and self.module_forward(). | {} |

Returns:

| Type | Description |
| --- | --- |
| (Dict[str, Tensor], Dict[str, Tensor]) or Dict[str, Tensor] | The returned values differ between the training and validation branches. 1. For training, two Dict[str, torch.Tensor] should be returned: the first contains all the trainable training losses for optimization and the second contains all the non-trainable evaluation metrics used to record the training status. 2. For validation, only one Dict[str, torch.Tensor] should be returned, containing all the non-trainable evaluation metrics used to record the validation status. |

Source code in speechain/model/abs.py
@abstractmethod
def criterion_forward(
    self, **kwargs
) -> (Dict[str, torch.Tensor], Dict[str, torch.Tensor]) or Dict[str, torch.Tensor]:
    """This interface function is activated after `self.model_forward()`. It
    receives the model prediction results from `self.model_forward()` and input
    batch data from `self.batch_preprocess_fn()`.

    Args:
        **kwargs:
            The combination of the returned arguments from `self.batch_preprocess_fn()` and `self.module_forward()`.

    Returns: (Dict[str, torch.Tensor], Dict[str, torch.Tensor]) or Dict[str, torch.Tensor]
        The returned values should be different for the training and validation branches.
        1. For training, two Dict[str, torch.Tensor] should be returned where the first one contains all the
        trainable training losses for optimization and the second one contains all the non-trainable evaluation
        metrics used to record the training status.
        2. For validation, only one Dict[str, torch.Tensor] should be returned which contains all the non-trainable
        evaluation metrics used to record the validation status.
    """
    pass  # raise NotImplementedError
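
A minimal sketch of a subclass implementation, assuming hypothetical Criterion members `self.ce_loss` and `self.accuracy` built in `criterion_init()`:

```
def criterion_forward(self, logits, text, text_len, **kwargs):
    # combine the batch labels and the model predictions into loss and metric values
    loss = self.ce_loss(logits=logits, text=text, text_len=text_len)
    accuracy = self.accuracy(logits=logits, text=text, text_len=text_len)

    losses = dict(loss=loss)
    # metrics must be non-trainable tensors, so everything is detached
    metrics = dict(loss=loss.clone().detach(), accuracy=accuracy.detach())
    # training returns (losses, metrics); validation returns metrics only
    return (losses, metrics) if self.training else metrics
```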

criterion_init(**criterion_conf) abstractmethod

The interface function that initializes the Criterion members of the model. These Criterion members can be divided into two parts: the loss functions used for training and the evaluation metrics used for validation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| **criterion_conf |  | The arguments in your given criterion_conf. | {} |
Source code in speechain/model/abs.py
@abstractmethod
def criterion_init(self, **criterion_conf) -> None:
    """
    The interface function that initializes the Criterion members of the model. These Criterion members can be
    divided into two parts: the loss functions used for training and the evaluation metrics used for validation.

    Args:
        **criterion_conf:
            The arguments in your given `criterion_conf`.

    """
    pass  # raise NotImplementedError
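
For example, a subclass might initialize its Criterion members as below; `CrossEntropy` and `Accuracy` are stand-ins for whatever criteria your model actually uses:

```
def criterion_init(self, label_smoothing: float = 0.0):
    # hypothetical Criterion classes: a training loss and a validation metric
    self.ce_loss = CrossEntropy(label_smoothing=label_smoothing)
    self.accuracy = Accuracy()
```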

evaluate(test_batch, infer_conf)

The evaluation function shared by all Model subclasses. This evaluation function has 2 steps: 1. preprocess and transfer the batch data to GPUs 2. calculate the inference results

For each step above, we provide interface functions for you to override and make your own implementation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| test_batch | Dict | The input batch data received from the test dataloader object in the experimental pipeline. | required |
| infer_conf | Dict | The configuration used for model inference. | required |

Returns:

A Dict of the inference results where each key-value item corresponds to one evaluation metric you want to save to the disk.

Source code in speechain/model/abs.py
def evaluate(self, test_batch: Dict, infer_conf: Dict):
    """
    The evaluation function shared by all _Model_ subclasses. This evaluation function has 2 steps:
        1. preprocess and transfer the batch data to GPUs
        2. calculate the inference results

    For each step above, we provide interface functions for you to override and make your own implementation.

    Args:
        test_batch: Dict
            The input batch data received from the `test` dataloader object in the experimental pipeline.
        infer_conf: Dict
            The configuration used for model inference.

    Returns:
        A Dict of the inference results where each key-value item corresponds to one evaluation metric you want to
        save to the disk.

    """
    # preprocess the batch data if needed
    test_batch = self.batch_preprocess_fn(test_batch)

    # put the batch data onto GPUs
    test_batch = self.batch_to_cuda(test_batch)

    # get the inference results
    evaluate_results = self.inference(infer_conf=infer_conf, **test_batch)
    if (
        hasattr(self, "instance_report_cache")
        and self.instance_report_cache is not None
    ):
        evaluate_results["instance_reports.md"] = self.instance_report_cache
        self.instance_report_cache = None

    # post-check the format of evaluate_results
    if isinstance(evaluate_results, Dict):
        for key, value in evaluate_results.items():
            if "format" not in value.keys() or "content" not in value.keys():
                raise RuntimeError(
                    "Each element of the returned value of self.inference() must contain the keys "
                    "named both 'format' and 'content'!"
                )
    else:
        raise RuntimeError(
            f"The returned value of self.inference() must be a Dict, "
            f"but got {type(evaluate_results)}!"
        )
    return evaluate_results

forward(batch_data, epoch=None, **kwargs)

The general model forward function shared by all the Model subclasses. This forward function has 3 steps: 1. preprocess and transfer the batch data to GPUs 2. obtain the model prediction results 3. calculate the loss function and evaluate the prediction results

For each step above, we provide interface functions for you to override and make your own implementation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_data | Dict | The input batch data received from the train or valid dataloader object in the experimental pipeline. | required |
| epoch | int | The number of the current epoch. Used for real-time model visualization and model prediction. | None |
| **kwargs |  | The additional arguments for real-time model visualization. If given, the code will go through the model visualization branch. | {} |

Returns:

In the training branch, the loss functions and evaluation metrics will be returned, each in the form of a Dict. In the validation branch, only the evaluation metrics will be returned. In the visualization branch, the model snapshots on the given validation instance will be returned.

Source code in speechain/model/abs.py
def forward(self, batch_data: Dict, epoch: int = None, **kwargs):
    """
    The general model forward function shared by all the _Model_ subclasses. This forward function has 3 steps:
        1. preprocess and transfer the batch data to GPUs
        2. obtain the model prediction results
        3. calculate the loss function and evaluate the prediction results

    For each step above, we provide interface functions for you to override and make your own implementation.

    Args:
        batch_data: Dict
            The input batch data received from the `train` or `valid` dataloader object in the experimental
            pipeline.
        epoch: int = None
            The number of the current epoch. Used for real-time model visualization and model prediction.
        **kwargs:
            The additional arguments for real-time model visualization. If given, the code will go through the model
            visualization branch.

    Returns:
        In the training branch, the loss functions and evaluation metrics will be returned each of which is in the
        form of a Dict.
        In the validation branch, only the evaluation metrics will be returned.
        In the visualization branch, the model snapshots on the given validation instance will be returned.

    """
    # --- 1. Batch Data Preprocessing and GPU transferring --- #
    # --- data preparation below is shared by all the three branches: training, validation, and visualization --- #
    # preprocess the batch data if needed
    batch_data = self.batch_preprocess_fn(batch_data)

    # put the batch data onto GPUs
    batch_data = self.batch_to_cuda(batch_data)

    # --- 2.1. Model Visualization Branch --- #
    # if there are additional arguments other than batch_data and epoch, the visualization branch is activated
    if len(kwargs) != 0:
        return self.visualize(epoch=epoch, **batch_data, **kwargs)

    # --- 2.2. Model Forward Calculation --- #
    # --- model forward is shared by both the training and validation branches --- #
    # forward context: the normal autograd context during training, inference_mode during validation
    forward_context = nullcontext if self.training else torch.inference_mode
    with forward_context():
        try:
            # Feed the input batch into the model and get the outputs, copy.deepcopy() here is for the data safety
            model_outputs = self.module_forward(
                epoch=epoch, **copy.deepcopy(batch_data)
            )
        except Exception as e:
            if not self.distributed:
                raise e
            else:
                skip_flag_list = torch.LongTensor(
                    [False for _ in range(torch.distributed.get_world_size())]
                ).cuda(self.device)
                skip_flag = torch.LongTensor([True]).cuda(self.device)
                # as long as one node meets an error, all nodes will skip the current step at the same time
                torch.distributed.all_gather_into_tensor(skip_flag_list, skip_flag)
                if skip_flag_list.sum() >= 1:
                    raise e
        else:
            if self.distributed:
                skip_flag_list = torch.LongTensor(
                    [False for _ in range(torch.distributed.get_world_size())]
                ).cuda(self.device)
                skip_flag = torch.LongTensor([False]).cuda(self.device)
                # as long as one node meets an error, all nodes will skip the current step at the same time
                torch.distributed.all_gather_into_tensor(skip_flag_list, skip_flag)
                if skip_flag_list.sum() >= 1:
                    raise RuntimeError(
                        "Other ranks meet errors during model forwarding, "
                        "so this rank will also skip the current step!"
                    )

    # copy.deepcopy() cannot handle the non-leaf nodes of the computation graph (model_outputs), and
    # model_outputs cannot be detached from the graph since its gradients are needed, so copy.deepcopy()
    # is not used below.
    def combine_input_output(_batch_data: Dict, _model_outputs: Dict):
        combination, batch_keys = dict(), list(_batch_data.keys())
        # if the input batch data is in the form of Dict, it means there are multiple dataloaders
        if isinstance(_batch_data[batch_keys[0]], Dict):
            for key in batch_keys:
                combination[key] = dict(**_batch_data[key], **_model_outputs[key])
        # if the input batch data is in the form of Tensor, it means there is only one dataloader.
        else:
            combination.update(_batch_data)
            combination.update(_model_outputs)
        return combination

    # --- 3.1. Model Training Branch --- #
    if self.training:
        # In the training stage, both the trainable losses and non-trainable metrics will be returned
        losses, metrics = self.criterion_forward(
            **combine_input_output(batch_data, model_outputs)
        )
        metrics.update(self.get_recordable_para())

        # post-checking for training losses, they must be trainable tensors
        assert sum(
            [
                isinstance(loss, torch.Tensor) and loss.requires_grad
                for loss in losses.values()
            ]
        ) == len(losses), "Training losses must be trainable tensors!"
        # post-checking for validation metrics, they must be either non-trainable tensors or other datatypes
        assert sum(
            [
                not isinstance(metric, torch.Tensor) or not metric.requires_grad
                for metric in metrics.values()
            ]
        ) == len(
            metrics
        ), "Validation metrics must be either non-trainable tensors or other datatypes!"

        # the non-trainable metrics will be averaged across all the processes in the distributed mode
        if self.distributed:
            metrics = self.aver_metrics_across_procs(metrics, batch_data)
        return losses, metrics

    # --- 3.2. Model Validation Branch --- #
    else:
        # In the validation stage, only the non-trainable metrics will be returned
        with torch.inference_mode():
            metrics = self.criterion_forward(
                **combine_input_output(batch_data, model_outputs)
            )
        metrics.update(self.get_recordable_para())

        # post-checking for validation metrics, they must be either non-trainable tensors or other datatypes
        assert sum(
            [
                not isinstance(metric, torch.Tensor) or not metric.requires_grad
                for metric in metrics.values()
            ]
        ) == len(
            metrics
        ), "Validation metrics must be either non-trainable tensors or other datatypes!"

        # the non-trainable metrics will be averaged across all the processes in the distributed mode
        if self.distributed:
            metrics = self.aver_metrics_across_procs(metrics, batch_data)
        return metrics
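
Seen from the runner side, the three branches are selected by the module mode and the presence of extra keyword arguments; a sketch with illustrative variable names:

```
# training branch: returns (losses, metrics)
model.train()
losses, metrics = model(batch_data=train_batch, epoch=epoch)

# validation branch: returns metrics only
model.eval()
metrics = model(batch_data=valid_batch, epoch=epoch)

# visualization branch: extra kwargs beyond batch_data and epoch activate it
vis_logs = model(batch_data=valid_sample, epoch=epoch, sample_index="sample-0001")
```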

get_recordable_para()

Recursively retrieves the recordable parameters from the module's sub-modules.

Returns:

| Type | Description |
| --- | --- |
| Dict[str, torch.Tensor] | A dictionary mapping the parameter names to their corresponding tensor values. |

Source code in speechain/model/abs.py
def get_recordable_para(self) -> Dict[str, torch.Tensor]:
    """Recursively retrieves the recordable parameters from the module's sub-
    modules.

    Returns:
        Dict[str, torch.Tensor]: A dictionary mapping the parameter names to their corresponding tensor values.
    """

    def recur_get_module_recordable_para(curr_node, prefix_list: List[str] = None):
        if prefix_list is None:
            prefix_list = []
        if isinstance(curr_node, Dict):
            _output = dict()
            for _key, _value in curr_node.items():
                _output.update(
                    recur_get_module_recordable_para(_value, prefix_list + [_key])
                )
            return _output
        else:
            if curr_node is None:
                return {}
            elif isinstance(curr_node, torch.Tensor):
                return {"_".join(prefix_list): curr_node.clone().detach()}
            else:
                raise RuntimeError

    output = dict()
    for key, value in self._modules.items():
        if isinstance(value, Module):
            output.update(
                recur_get_module_recordable_para(value.get_recordable_para(), [key])
            )
    return output

inference(infer_conf, **kwargs) abstractmethod

This function receives the test data and test configuration. The inference results will be packaged into a Dict[str, Dict] which is passed to TestMonitor for disk storage. The returned Dict should be in the form of

```
dict(
    {file_name}=dict(
        format={file_format},
        content={file_content}
    )
)
```

The first-level key is used to decide the name of the meta file as `idx2{file_name}`. Its value is also a Dict and there must be two keys in this sub-Dict: 'format' and 'content'. The configuration of the sub-Dict differs by file format:

1. For pure text metadata files, the value of 'format' must be 'txt' and the value of 'content' must be a
list of Python built-in data types (i.e., int, float, str, bool, ...).
Each line of the file `idx2{file_name}` will be made up of the index of a test data instance and its
metadata value in the `content` List, separated by a blank.
For example,
`dict(cer=dict(format='txt', content=[0.1, 0.2, 0.3]))` will create a pure text file named 'idx2cer' which
looks like
```
{test_index1} 0.1
{test_index2} 0.2
{test_index3} 0.3
```
Note: if the first-level key ends with '.md', 'idx2' will not be attached to the beginning of the
file name.

2. For audio files, the value of 'format' must be either 'wav' or 'flac' and the value of 'content' must be
a list of array-like data types (e.g., numpy.ndarray, torch.Tensor, ...).
Moreover, there must be an additional key named 'sample_rate' to indicate the sampling rate of the waveforms
to be saved in audio files.
There will be a folder named `{file_name}` that contains all the audio files and a pure text file named
`idx2{file_name}` that contains the absolute paths of all the saved audio files.
For example,
`dict(wav=dict(format='flac', content=[np_arr1, np_arr2, np_arr3]))` will create a folder named 'wav' and
a pure text file named 'idx2wav' in the same directory. The file 'idx2wav' looks like:
```
{test_index1} /x/xx/wav/{test_index1}.flac
{test_index2} /x/xx/wav/{test_index2}.flac
{test_index3} /x/xx/wav/{test_index3}.flac
```
where `/x/xx/` is your result path given in your `exp_cfg`.

3. For binary files, the value of 'format' in the sub-Dict must be 'npy' and the value of 'content' must be
a list of numpy.ndarray (torch.Tensor is not supported).
There will be a folder named `{file_name}` that contains all the .npy files and a pure text file
named `idx2{file_name}` that contains the absolute paths of all the saved binary files.
For example,
`dict(feat=dict(format='npy', content=[np_arr1, np_arr2, np_arr3]))`
will create a folder named 'feat' and a pure text file named 'idx2feat'. The 'idx2feat' file is like:
```
{test_index1} /x/xx/feat/{test_index1}.npy
{test_index2} /x/xx/feat/{test_index2}.npy
{test_index3} /x/xx/feat/{test_index3}.npy
```
where `/x/xx/` is your result path given in your `exp_cfg`.
Source code in speechain/model/abs.py
@abstractmethod
def inference(
    self, infer_conf: Dict, **kwargs
) -> Dict[str, Dict[str, str or List]]:
    """This function receives the test data and test configuration. The inference
    results will be packaged into a Dict[str, Dict] which is passed to TestMonitor
    for disk storage. The returned Dict should be in the form of ``` dict(
    {file_name}=dict( format={file_format},

            content={file_content}
        )
    )
    ```
    The first-level key is used to decide the name of the meta file as `idx2{file_name}`. Its value is also a Dict
    and there must be two keys in this sub-Dict: 'format' and 'content'. The configuration of the sub-Dict is
    different for different file formats:

        1. For pure text metadata files, the value of 'format' must be 'txt' and the value of 'content' must be a
        list of Python built-in data types (i.e., int, float, str, bool, ...).
        Each line of the file `idx2{file_name}` will be made up of the index of a test data instance and its
        metadata value in the `content` List which are separated by a blank.
        For example,
        `dict(cer=dict(format='txt', content=[0.1, 0.2, 0.3]))` will create a pure text file named 'idx2cer' which
        looks like
        ```
        {test_index1} 0.1
        {test_index2} 0.2
        {test_index3} 0.3
        ```
        Note: if the first-level key ends with '.md', there will not be 'idx2' attached at the beginning of the
        file name.

        2. For audio files, the value of 'format' must be either 'wav' or 'flac' and the value of 'content' must be
        a list of array-like data types (e.g., numpy.ndarray, torch.Tensor, ...).
        Moreover, there must be an additional key named 'sample_rate' to indicate the sampling rate of the waveforms
        to be saved in audio files.
        There will be a folder named `{file_name}` that contains all the audio files and a pure text file named
        `idx2{file_name}` that contains the absolute paths of all the saved audio files.
        For example,
        `dict(wav=dict(format='flac', content=[np_arr1, np_arr2, np_arr3]))` will create a folder named 'wav' and
        a pure text file named 'idx2wav' in the same directory. The file 'idx2wav' looks like:
        ```
        {test_index1} /x/xx/wav/{test_index1}.flac
        {test_index2} /x/xx/wav/{test_index2}.flac
        {test_index3} /x/xx/wav/{test_index3}.flac
        ```
        where `/x/xx/` is your result path given in your `exp_cfg`.

        3. For binary files, the value of 'format' in the sub-Dict must be 'npy' and the value of 'content' must be
        a list of numpy.ndarray (torch.Tensor is not supported).
        There will be a folder named `{file_name}` that contains all the .npy files and a pure text file
        named `idx2{file_name}` that contains the absolute paths of all the saved binary files.
        For example,
        `dict(feat=dict(format='npy', content=[np_arr1, np_arr2, np_arr3]))`
        will create a folder named 'feat' and a pure text file named 'idx2feat'. The 'idx2feat' file is like:
        ```
        {test_index1} /x/xx/feat/{test_index1}.npy
        {test_index2} /x/xx/feat/{test_index2}.npy
        {test_index3} /x/xx/feat/{test_index3}.npy
        ```
        where `/x/xx/` is your result path given in your `exp_cfg`.
    """
    pass  # raise NotImplementedError
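
Putting the format rules together, a hypothetical `inference()` implementation for a recognition model might return:

```
def inference(self, infer_conf: Dict, **test_batch) -> Dict[str, Dict]:
    # ... decode the batch and compute per-instance metrics (omitted) ...
    hypo_text, cer_list = ["hello world"], [0.1]
    return dict(
        # creates the pure text file 'idx2text'
        text=dict(format="txt", content=hypo_text),
        # creates the pure text file 'idx2cer'
        cer=dict(format="txt", content=cer_list),
    )
```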

matrix_snapshot(vis_logs, hypo_attention, subfolder_names, epoch)

Used by the abstract function visualize() to make the snapshot materials for attention matrices.

Source code in speechain/model/abs.py
def matrix_snapshot(
    self,
    vis_logs: List,
    hypo_attention: Dict,
    subfolder_names: List[str] or str,
    epoch: int,
):
    """Used by the abstract function visualize() to make the snapshot materials for
    attention matrices."""
    if isinstance(subfolder_names, str):
        subfolder_names = [subfolder_names]
    keys = list(hypo_attention.keys())

    # process the input data by different data types
    if isinstance(hypo_attention[keys[0]], Dict):
        for key, value in hypo_attention.items():
            self.matrix_snapshot(
                vis_logs=vis_logs,
                hypo_attention=value,
                subfolder_names=subfolder_names + [key],
                epoch=epoch,
            )

    # snapshot the information in the materials
    elif isinstance(hypo_attention[keys[0]], np.ndarray):
        vis_logs.append(
            dict(
                plot_type="matrix",
                materials=hypo_attention,
                epoch=epoch,
                sep_save=False,
                data_save=True,
                subfolder_names=subfolder_names,
            )
        )

module_forward(epoch=None, **batch_data) abstractmethod

This function forwards the input batch data by all Module members. Note: 1. This interface function must be overridden for each Model subclass. 2. The argument names should match the key names in the returned Dict of self.batch_preprocess_fn(). 3. The key names in the returned Dict should match the argument names of self.criterion_forward().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epoch | int | The number of the current epoch, received from self.forward(). | None |
| **batch_data |  | Processed data of the input batch received from self.batch_preprocess_fn(). | {} |

Returns:

| Type | Description |
| --- | --- |
| Dict | Prediction results (logits) of the model on the input batch data. Some intermediate results (e.g., attention matrices) can also be returned for later use. |

Source code in speechain/model/abs.py
@abstractmethod
def module_forward(self, epoch: int = None, **batch_data) -> Dict:
    """
    This function forwards the input batch data by all _Module_ members.
    Note:
        1. This interface function must be overridden for each Model subclass.
        2. The argument names should match the key names in the returned Dict of `self.batch_preprocess_fn()`.
        3. The key names in the returned Dict should match the argument names of
        `self.criterion_forward()`.

    Args:
        epoch:
            The number of the current epoch, received from self.forward().
        **batch_data:
            Processed data of the input batch received from `self.batch_preprocess_fn()`.

    Returns: Dict
        Prediction results (logits) of the model on the input batch data.
        Some intermediate results (e.g., attention matrices) can also be returned for later use.

    """
    pass  # raise NotImplementedError
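
A minimal encoder-decoder sketch of an override; the module and argument names are hypothetical:

```
def module_forward(self, epoch: int = None, feat=None, feat_len=None, text=None, text_len=None, **kwargs):
    # `self.encoder` and `self.decoder` are Module members built in module_init()
    enc_output, enc_output_len, enc_att = self.encoder(feat, feat_len)
    logits, dec_att = self.decoder(enc_output, enc_output_len, text, text_len)
    # the key names here must match the argument names of criterion_forward()
    return dict(logits=logits, enc_att=enc_att, dec_att=dec_att)
```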

module_init(**kwargs) abstractmethod

The interface function that initializes the Module members of the model. These Module members make up the neural network structure of the model. Some models have a customized part that also needs to be initialized in this function, e.g. the tokenizer of ASR and TTS models.

Note: This interface function must be overridden for each Model subclass.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| **kwargs |  | The combination of the arguments in your given module_conf and model_conf['customize_conf']. | {} |
Source code in speechain/model/abs.py
@abstractmethod
def module_init(self, **kwargs) -> None:
    """The interface function that initializes the Module members of the model.
    These Module members make up the neural network structure of the model. Some
    models have a customized part that also needs to be initialized in this
    function, e.g. the tokenizer of ASR and TTS models.

    Note: This interface function must be overridden for each Model subclass.

    Args:
        **kwargs:
            The combination of the arguments in your given `module_conf` and `model_conf['customize_conf']`.
    """
    pass  # raise NotImplementedError
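
A sketch of an override, where `encoder`/`decoder` would come from `module_conf`, `token_vocab` from `model_conf['customize_conf']`, and `Encoder`, `Decoder` and `Tokenizer` are placeholder classes:

```
def module_init(self, encoder: Dict, decoder: Dict, token_vocab: str = None):
    # customized part: build the tokenizer before the network modules
    if token_vocab is not None:
        self.tokenizer = Tokenizer(token_vocab)  # hypothetical helper class

    # Module members that make up the network structure of the model
    self.encoder = Encoder(**encoder)
    self.decoder = Decoder(**decoder)
```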

register_instance_reports(md_list_dict, extra_string_list=None)

Caches a Markdown instance report for the current testing step; the cache is saved to the disk as 'instance_reports.md' by self.evaluate().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| md_list_dict | Dict[str, List] | A Dict mapping report field names to equal-length lists of per-instance values. | required |
| extra_string_list | List[str] | An optional list of extra strings, one per instance, appended to the end of each report. | None |

Source code in speechain/model/abs.py
def register_instance_reports(
    self, md_list_dict: Dict[str, List], extra_string_list: List[str] = None
):
    """

    Args:
        md_list_dict:
        extra_string_list:

    Returns:

    """
    # --- 1. Arguments Checking --- #
    if extra_string_list is not None:
        assert isinstance(extra_string_list, List)

    ele_len = []
    for value in md_list_dict.values():
        assert isinstance(value, List)
        if extra_string_list is not None:
            assert len(value) == len(extra_string_list)
        ele_len.append(len(value))

    if len(set(ele_len)) == 1:
        ele_len = ele_len[0]
    else:
        raise RuntimeError("All the value Lists in md_list_dict must have the same length!")

    # --- 2. Generate .md Instance Report for the current step --- #
    instance_reports = []
    for i in range(ele_len):
        ele_dict = {
            key: value[i] if isinstance(value[i], str) else str(value[i])
            for key, value in md_list_dict.items()
        }
        _curr_report = "\n\n" + get_list_strings(ele_dict) + "\n"

        if extra_string_list is not None:
            _curr_report += extra_string_list[i] + "\n"
        instance_reports.append(_curr_report)

    self.instance_report_cache = dict(format="txt", content=instance_reports)
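
For example, an `inference()` implementation could cache one Markdown section per test instance like this (the field names are made up):

```
# called inside inference(); evaluate() later saves the cache as 'instance_reports.md'
self.register_instance_reports(
    md_list_dict={"Hypothesis": hypo_text, "Reference": ref_text, "CER": cer_list}
)
```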

visualize(epoch, sample_index, **valid_sample) abstractmethod

Makes the model snapshots on the given validation instance for real-time visualization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epoch | int | The number of the current epoch. | required |
| sample_index | str | The index of the validation instance to be visualized. | required |
| **valid_sample |  | The data of the validation instance to be visualized. | {} |

Source code in speechain/model/abs.py
@abstractmethod
def visualize(self, epoch: int, sample_index: str, **valid_sample):
    """

    Args:
        epoch:
        sample_index:
        **valid_sample:

    Returns:

    """
    pass  # raise NotImplementedError