Model
Model is the hub where different Module and Criterion objects are freely assembled to create a model. It encapsulates the general model-related services and provides sufficient interface functions for you to override to customize your own models.
Table of Contents
- Configuration File Format
- Model Library
- API Document
- Supported Models
- How to Freeze a Specific Part of your Model
- How to Initialize your Model by the Pretrained Models
Configuration File Format
The configuration of your model is given in train_cfg. The configuration format is shown below.
```yaml
model:
    model_type: {file_name}.{class_name}
    model_conf:
        init: {init_function}
        frozen_modules:
            - {frozen_module1}
            - {frozen_module2}
            - ...
        pretrained_model:
            - path: {model_path1}
              mapping:
                  {src_name1}: {tgt_name1}
                  {src_name2}: {tgt_name2}
                  ...
            - path: {model_path2}
            - ...
        visual_infer_conf:
            ...
        customize_conf:
            {customize_arg1}: {arg_value1}
            {customize_arg2}: {arg_value2}
            ...
    module_conf:
        ...
    criterion_conf:
        ...
```
- `model_type` is used as the query to pick up your target Model subclass in `{SPEECHAIN_ROOT}/speechain/model/` for model initialization. Your given query should be in the form of `{file_name}.{class_name}`, e.g., `asr.ASR` means the subclass `ASR` in `{SPEECHAIN_ROOT}/speechain/model/asr.py`.
- `model_conf` contains the general configuration of your model. It is made up of the following 5 parts:
  - `init` indicates the function used to initialize the parameters of your model before training. The available initialization functions are given by the keys of the built-in dictionary `init_class_dict`; please refer to it for more details.
  - `frozen_modules` contains the names of the modules that don't need to be updated during training. If a list of module names is given, all of those modules will be frozen.
  - `pretrained_model` contains the pretrained models you would like to load into your model as the initial parameters. If a list of pretrained models is given, all of them will be used to initialize your model.
    - `path` indicates where the pretrained model file is placed.
    - `mapping` is a dictionary used to solve the mismatch between the parameter names of the pretrained model and those of the model you want to train. Each key-value item solves one name mismatch: the key is the name in the pretrained model and the value is the name in the model to be trained.
  - `visual_infer_conf` contains the inference configuration you want to use for model visualization during training. This argument defaults to an empty dictionary, which means the default inference configuration of each model will be used. For more details, please refer to the docstring of `inference()` of each Model subclass.
  - `customize_conf` will be used to initialize the main body of the model in the interface function `module_init()`. For more details about the argument setting, please refer to the README.md of each Model subclass.
- `module_conf` contains all the configuration for module initialization. These configuration arguments will be used to initialize the network structure of the model in the interface function `module_init()`. For more details about the argument setting, please refer to the README.md of each Model subclass.
- `criterion_conf` contains all the configuration for criterion initialization. These configuration arguments will be used to initialize all the criteria of the model in the interface function `criterion_init()`. For more details about the argument setting, please refer to the README.md of each Model subclass.
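As a concrete sketch, a filled-in configuration might look like the example below. The model type `asr.ARASR` comes from the Supported Models section; the `init` key name, the path, and the customized argument are illustrative placeholders, not values the toolkit guarantees:

```yaml
model:
    model_type: asr.ARASR
    model_conf:
        # assumed to be a key of init_class_dict; check the dictionary for exact names
        init: xavier_normal
        frozen_modules:
            - encoder_prenet
        pretrained_model:
            - path: /x/xx/models/accuracy_best.pth
        # empty Dict -> use the model's default inference configuration
        visual_infer_conf: {}
        customize_conf:
            token_type: char    # hypothetical customized argument
    module_conf:
        ...
    criterion_conf:
        ...
```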
👆Back to the table of contents
Model Library
```
/speechain
    /model
        /abs.py    # Abstract Model class. Base of all Model implementations.
        /asr.py    # All the model implementations of ASR.
        /tts.py    # All the model implementations of TTS.
```
👆Back to the table of contents
API Document
- Non-overridable backbone functions
- Overridable interface functions
👆Back to the table of contents
speechain.model.abs.Model
speechain.model.abs.Model is the base class for all models in this toolkit. The main job of a model includes:
- (optional) preprocess the input batch data to the trainable format
- calculate the model prediction results by the Module members
- evaluate the prediction results by the Criterion members
Each model has several built-in Module members that make up the neural network structure of the model. These Module members will be initialized by the `module_conf` given in your configuration.
There is a built-in dictionary named `init_class_dict` and a built-in list named `default_init_modules` in the base class. `init_class_dict` contains all the available initialization functions for the model parameters, while `default_init_modules` includes the network layers that have their own initialization functions.
__init__(self, args, device, model_conf, module_conf, criterion_conf)
- Description:
  In this initialization function, there are two parts of initialization: model-specific customized initialization and model-independent general initialization.
  - Model-specific customized initialization is done by two interface functions: `module_init()` and `criterion_init()`. `module_init()` initializes the neural network structure of the model, while `criterion_init()` initializes the criteria used to optimize (loss functions) and evaluate (validation metrics) the model.
  - After the customized initialization, there are 3 steps of general initialization shared by all Model subclasses:
    1. Pretrained parameters will be loaded into your model if the key `pretrained_model` is given. Multiple pretrained models can be specified, and each of them can be loaded into a different part of your model. The mismatch between the names of pretrained parameters and the parameters of your model is handled by the key `mapping`. The value of `mapping` is a dictionary where each key-value item corresponds to a mapping of parameter names: the key is the parameter name in the pretrained parameters, while the value is the parameter name in your model.
    2. If `pretrained_model` is not given, the parameters of your model will be initialized by the function that matches your input query `init`. For more details about the available initialization functions, please refer to the built-in dictionary `init_class_dict`. If `init` is not given, the default initialization function `torch.nn.init.xavier_normal_` will be used to initialize your model.
    3. Finally, the specified parts of your model will be frozen if `frozen_modules` is given. If there is only one frozen module, you can directly give the string of its name to `frozen_modules` like `frozen_modules: {module_name}`. If there are multiple modules you want to freeze, you can give their names in a list, as shown in the section How to Freeze a Specific Part of your Model below. Moreover, the freezing granularity depends on your input `frozen_modules`. For example:
       1. If you give `frozen_modules: encoder_prenet`, all parameters of the prenet of your encoder will be frozen.
       2. If you give `frozen_modules: encoder_prenet.conv`, only the convolution layers of the prenet of your encoder will be frozen.
       3. If you give `frozen_modules: encoder_prenet.conv.0`, only the first convolution layer of the prenet of your encoder will be frozen.
       4. If you give `frozen_modules: encoder_prenet.conv.0.bias`, only the bias vector of the first convolution layer of the prenet of your encoder will be frozen.
- Arguments:
  - args: argparse.Namespace
    Experiment pipeline arguments received from the `Runner` object in `runner.py`.
  - device: torch.device
    The computational device used for model calculation in the current GPU process.
  - model_conf: Dict
    The model configuration used for general model initialization.
  - module_conf: Dict
    The module configuration used for network structure initialization.
  - criterion_conf: Dict = None
    The criterion configuration used for criterion (loss functions and evaluation metrics) initialization.
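As a sketch of how the three general-initialization steps interact (the path and module name are hypothetical), the following `model_conf` loads pretrained parameters in step 1 and freezes the encoder in step 3; step 2 would only take effect if `pretrained_model` were absent:

```yaml
model_conf:
    # step 1: load pretrained parameters (placeholder path)
    pretrained_model:
        - path: /x/xx/models/accuracy_best.pth
    # step 2: init is skipped here because pretrained_model is given
    # step 3: freeze the encoder after it has been initialized
    frozen_modules: encoder
```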
batch_to_cuda(self, data)
- Description:
  The recursive function that transfers the batch data to the specified device in the current process.
- Arguments:
  - data: Dict or torch.Tensor
    The input batch data. It should be either a Tensor or a Dict of Tensors. For a Dict input, the function is called recursively on each Tensor element.
- Return: Dict or torch.Tensor
  If the input is a Dict, the returned output will also be a Dict of Tensors transferred to the target device; if the input is a Tensor, the returned output will be its copy on the target device.
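A minimal sketch of what such a recursive transfer can look like (this is an illustration, not the toolkit's actual implementation):

```python
from typing import Dict, Union
import torch

def batch_to_device(data: Union[Dict, torch.Tensor], device: torch.device):
    """Recursively move a Tensor or a (nested) Dict of Tensors to `device`."""
    if isinstance(data, torch.Tensor):
        # non_blocking=True lets the copy overlap with computation when possible
        return data.to(device, non_blocking=True)
    if isinstance(data, dict):
        # recurse into each element so nested Dicts are handled as well
        return {key: batch_to_device(value, device) for key, value in data.items()}
    # leave non-Tensor metadata (e.g., strings) untouched
    return data
```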
forward(self, batch_data, epoch, **kwargs)
- Description:
  The general model forward function shared by all the Model subclasses. This forward function has 3 steps:
  1. preprocess and transfer the batch data to GPUs
  2. obtain the model prediction results
  3. calculate the loss functions and evaluate the prediction results
  For each step above, we provide interface functions for you to override and make your own implementation.
- Arguments:
  - batch_data: Dict
    The input batch data received from the `train` or `valid` dataloader object in the experimental pipeline. The batch is in the form of a Dict where the key is the data name and the value is the data content.
  - epoch: int = None
    The number of the current epoch. Used for real-time model visualization and model prediction.
  - **kwargs:
    The additional arguments for real-time model visualization. If given, the code will go through the model visualization branch.
- Return:
  In the training branch, the loss functions and evaluation metrics will be returned, each in the form of a Dict. In the validation branch, only the evaluation metrics will be returned. In the visualization branch, the model snapshots on the given validation instance will be returned.
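A schematic sketch of this control flow, using the interface functions documented on this page; the exact branching and call signatures inside the toolkit may differ from this simplification:

```python
def forward(self, batch_data, epoch=None, **kwargs):
    # step 1: preprocess and transfer the batch data to GPUs
    batch_data = self.batch_preprocess_fn(batch_data)
    batch_data = self.batch_to_cuda(batch_data)
    # step 2: obtain the model prediction results
    model_outputs = self.module_forward(**batch_data)
    # visualization branch: triggered by the additional keyword arguments
    if kwargs:
        return self.visualize(epoch=epoch, **kwargs)
    # step 3: training -> (losses, metrics); validation -> metrics only
    return self.criterion_forward(**{**batch_data, **model_outputs})
```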
aver_metrics_across_procs(self, metrics, batch_data)
- Description:
  This function averages the evaluation metrics across all GPU processes when the model is distributed in DDP mode.
- Arguments:
  - metrics: Dict[str, torch.Tensor]
    The evaluation metrics to be averaged across all GPU processes.
  - batch_data: Dict
    The input batch data, used to calculate the batch size for averaging the evaluation metrics.
- Return: Dict[str, torch.Tensor]
  The evaluation metrics Dict after averaging. The key names remain the same.
evaluate(self, test_batch, infer_conf)
- Description:
  The evaluation function shared by all Model subclasses. This evaluation function has 2 steps:
  1. preprocess and transfer the batch data to GPUs
  2. calculate the inference results
  For each step above, we provide interface functions for you to override and make your own implementation.
- Arguments:
  - test_batch: Dict
    The input batch data received from the `test` dataloader object in the experimental pipeline.
  - infer_conf: Dict
    The configuration used for model inference.
- Return: Dict
  A Dict of the inference results where each key-value item corresponds to one evaluation metric you want to save to the disk.
bad_cases_selection_init_fn()
- Description:
  This hook function returns the default bad case selection method of each Model object. This default value is referred to by the Runner to present the top-N bad cases. The original hook implementation in the base Model class returns None, which means no default value.
- Return: List[List[str or int]]
  The returned default value should be a List of tri-Lists, where each tri-List is in the form of [`selection_metric`, `selection_mode`, `case_number`]. For example, ['wer', 'max', 50] means the 50 testing waveforms with the largest WER will be selected.
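An overriding sketch of this hook; the second entry's metric name and the 'min' mode are illustrative assumptions, not names the toolkit guarantees:

```python
@staticmethod
def bad_cases_selection_init_fn():
    # present the 50 utterances with the largest WER and, assuming a 'min'
    # mode exists, the 20 utterances with the smallest sentence confidence
    return [
        ['wer', 'max', 50],
        ['sent_prob', 'min', 20],
    ]
```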
module_init(self, **kwargs)
- Description:
  The interface function that initializes the Module members of the model. These Module members make up the neural network structure of the model. Some models have customized parts that also need to be initialized in this function, e.g., the tokenizer of ASR and TTS models.
  Note: This interface function must be overridden by each Model subclass.
- Arguments:
  - **kwargs:
    The combination of the arguments in your `module_conf` and `model_conf['customize_conf']`.
criterion_init(self, **criterion_conf)
- Description:
  The interface function that initializes the Criterion members of the model. These Criterion members can be divided into two parts: the loss functions used for training and the evaluation metrics used for validation.
  Note: This interface function must be overridden by each Model subclass.
- Arguments:
  - **criterion_conf:
    The arguments in your given `criterion_conf`.
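A sketch of overriding these two customized-initialization interfaces in a subclass; the class name, layer choices, and argument names are illustrative assumptions, not the toolkit's actual modules:

```python
import torch
from speechain.model.abs import Model

class ToyASR(Model):
    def module_init(self, input_size=80, enc_dim=256, vocab_size=500, **customize_conf):
        # Module members that make up the network structure of the model
        self.encoder = torch.nn.LSTM(input_size, enc_dim, batch_first=True)
        self.decoder = torch.nn.Linear(enc_dim, vocab_size)

    def criterion_init(self, label_smoothing=0.0, **criterion_conf):
        # Criterion members: a trainable loss; the evaluation metrics used
        # for validation would be initialized here as well
        self.cross_entropy = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
```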
batch_preprocess_fn(self, batch_data)
- Description:
  This hook function preprocesses the input batch data before it is used in `self.module_forward()`. Overriding this function is not mandatory; the original implementation in the base Model class does nothing but return the input `batch_data`.
  Note: the key names in the returned Dict should match the argument names of `self.module_forward()`.
- Arguments:
  - batch_data: Dict
    Raw data of the input batch to be preprocessed in this hook function.
- Return: Dict
  Processed data of the input batch, ready to be used in `self.module_forward()`.
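A sketch of overriding this hook in a Model subclass; the key name `feat` is illustrative and must match the argument names of your `module_forward()`:

```python
def batch_preprocess_fn(self, batch_data):
    # e.g., drop a redundant channel dimension added by the dataloader
    if batch_data['feat'].dim() == 4:
        batch_data['feat'] = batch_data['feat'].squeeze(1)
    return batch_data
```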
module_forward(self, **batch_data)
- Description:
  This interface function forwards the input batch data through all Module members.
  Note:
  - This interface function must be overridden by each Model subclass.
  - The argument names should match the key names in the returned Dict of `self.batch_preprocess_fn()`.
  - The key names in the returned Dict should match the argument names of `self.criterion_forward()`.
- Arguments:
  - **batch_data:
    Processed data of the input batch received from `self.batch_preprocess_fn()`.
- Return: Dict
  Prediction results (logits) of the model on the input batch data. Some intermediate results (e.g., attention matrices) can also be returned for later use.
criterion_forward(self, **kwargs)
- Description:
  This interface function is activated after `self.module_forward()`. It receives the model prediction results from `self.module_forward()` and the input batch data from `self.batch_preprocess_fn()`.
  Note: This interface function must be overridden by each Model subclass.
- Arguments:
  - **kwargs:
    The combination of the returned arguments from `self.batch_preprocess_fn()` and `self.module_forward()`.
- Return: (Dict[str, torch.Tensor], Dict[str, torch.Tensor]) or Dict[str, torch.Tensor]
  The returned values differ between the training and validation branches.
  - For training, two Dict[str, torch.Tensor] should be returned: the first contains all the trainable training losses for optimization, and the second contains all the non-trainable evaluation metrics used to record the training status.
  - For validation, only one Dict[str, torch.Tensor] should be returned, containing all the non-trainable evaluation metrics used to record the validation status.
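The sketch below illustrates the name-matching contract between the two interface functions; all member and key names are illustrative assumptions about a subclass, not the toolkit's actual API:

```python
# module_forward() returns a Dict whose keys ('logits' here) become
# arguments of criterion_forward(); the remaining batch data ('text', ...)
# is passed through as well.
def module_forward(self, feat, feat_len, text, text_len):
    enc_out = self.encoder(feat, feat_len)     # hypothetical Module member
    logits = self.decoder(enc_out, text)       # hypothetical Module member
    return dict(logits=logits)

def criterion_forward(self, logits, text, text_len, **batch_data):
    loss = self.cross_entropy(logits, text, text_len)  # hypothetical Criterion member
    accuracy = self.accuracy(logits, text)             # hypothetical Criterion member
    if self.training:
        # training branch: (trainable losses, non-trainable metrics)
        return dict(loss=loss), dict(accuracy=accuracy.detach())
    # validation branch: non-trainable metrics only
    return dict(accuracy=accuracy.detach())
```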
visualize(self, epoch, sample_index, **valid_sample)
- Description:
- Arguments:
- Return:
inference(self, infer_conf, **kwargs)
- Description:
  This function receives the test data and test configuration. The inference results will be packaged into a Dict[str, Dict] which is passed to the TestMonitor object for disk storage. The returned Dict should be in the form of `dict({file_name}=dict(format=..., content=...))`. The first-level key is used to decide the name of the meta file as `idx2{file_name}`. (Note: if the first-level key ends with .md, 'idx2' will not be attached at the beginning of the file name.) Its value is also a Dict, and there must be two keys in this sub-Dict: `format` and `content`. The configuration of the sub-Dict is different for different file formats:
  1. For pure text metadata files, the value of `format` must be `txt` and the value of `content` must be a List of Python built-in data types (i.e., int, float, str, bool, ...). Each line of the file `idx2{file_name}` will be made up of the index of a test data instance and its metadata value in the `content` List, separated by a blank. For example, `dict(cer=dict(format='txt', content=[0.1, 0.2, 0.3]))` will create a pure text file named `idx2cer` which looks like:
     ```
     {test_index1} 0.1
     {test_index2} 0.2
     {test_index3} 0.3
     ```
  2. For audio files, the value of `format` must be either `wav` or `flac` and the value of `content` must be a List of array-like data types (e.g., numpy.ndarray, torch.Tensor, ...). Moreover, there must be an additional key named `sample_rate` to indicate the sampling rate of the waveforms to be saved in audio files. There will be a folder named `{file_name}` that contains all the audio files and a pure text file named `idx2{file_name}` that contains the absolute paths of all the saved audio files. For example, `dict(wav=dict(format='flac', content=[np_arr1, np_arr2, np_arr3]))` will create a folder named `wav` and a pure text file named `idx2wav` in the same directory. The file `idx2wav` looks like:
     ```
     {test_index1} /x/xx/wav/{test_index1}.flac
     {test_index2} /x/xx/wav/{test_index2}.flac
     {test_index3} /x/xx/wav/{test_index3}.flac
     ```
     where `/x/xx/` is your result path given in your `exp_cfg`.
  3. For binary files, the value of `format` in the sub-Dict must be `npy` and the value of `content` must be a List of numpy.ndarray (torch.Tensor is not supported). There will be a folder named `{file_name}` that contains all the .npy files and a pure text file named `idx2{file_name}` that contains the absolute paths of all the saved binary files. For example, `dict(feat=dict(format='npy', content=[np_arr1, np_arr2, np_arr3]))` will create a folder named `feat` and a pure text file named `idx2feat`. The `idx2feat` file looks like:
     ```
     {test_index1} /x/xx/feat/{test_index1}.npy
     {test_index2} /x/xx/feat/{test_index2}.npy
     {test_index3} /x/xx/feat/{test_index3}.npy
     ```
     where `/x/xx/` is your result path given in your `exp_cfg`.
- Arguments:
  - infer_conf: Dict
    The configuration Dict used for model inference.
  - **kwargs:
    The testing data loaded from the `test` dataloader object in the experimental pipeline.
- Return: Dict[str, Dict[str, str or List]]
  The model inference results to be saved on the disk.
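Putting the formats together, an `inference()` override might return something like the sketch below; all values are placeholders, and `sample_rate` is included per the audio-file requirement above:

```python
import numpy as np

def inference(self, infer_conf, **test_batch):
    # ... the actual decoding / synthesis would happen here ...
    return dict(
        cer=dict(format='txt', content=[0.1, 0.2, 0.3]),
        wav=dict(format='flac', sample_rate=16000,
                 content=[np.zeros(16000), np.zeros(16000), np.zeros(16000)]),
        feat=dict(format='npy',
                  content=[np.zeros((10, 80)), np.zeros((12, 80)), np.zeros((9, 80))]),
    )
```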
👆Back to the table of contents
Supported Models
- ASR Recipes
  - asr.ARASR
    - Structure: Auto-Regressive CTC-Attention ASR model.
    - Input: One tuple of speech-text paired data (feat, feat_len, text, text_len) in `module_forward()`.
    - Output: One ASR loss calculated on the input data tuple in `criterion_forward()`.
  - asr.SemiARASR
    - Structure: Semi-supervised Auto-Regressive CTC-Attention ASR model.
    - Input: Multiple tuples of speech-text paired data (feat, feat_len, text, text_len) in `module_forward()`. Each of them is generated by a specific `torch.utils.data.DataLoader`.
    - Output: Multiple ASR losses calculated on all the input data tuples in `criterion_forward()`. A loss named `loss` is also returned, which is the trainable overall loss summed over all the ASR losses.
- TTS Recipes
  - tts.ARTTS
    - Structure: Auto-Regressive Attention TTS model.
    - Input: One tuple of speech-text paired data (feat, feat_len, text, text_len) in `module_forward()`.
    - Output: One TTS loss calculated on the input data tuple in `criterion_forward()`.
👆Back to the table of contents
How to Freeze a Specific Part of your Model
Parameter freezing can be done simply by giving the name of the module you want to freeze in `frozen_modules`. In the example below, the encoder of the ASR model will be frozen while the other modules remain trainable.
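A sketch of such a configuration, assuming the ASR model names its encoder module `encoder` (other fields are abbreviated):

```yaml
model:
    model_type: asr.ARASR
    model_conf:
        frozen_modules: encoder
```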
If you want to freeze multiple modules, you can give their names as a list in `frozen_modules`; in the first example below, the prenets of both the encoder and decoder will be frozen. The parameter freezing granularity can be very fine if you specify the module name by a series of dots: in the second example below, only the convolution layers of the prenet of the encoder will be frozen.
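Both sketches reuse the module names from the granularity examples above (they are assumptions about your model's structure):

```yaml
model:
    model_type: asr.ARASR
    model_conf:
        frozen_modules:
            - encoder_prenet
            - decoder_prenet
```

```yaml
model:
    model_type: asr.ARASR
    model_conf:
        frozen_modules: encoder_prenet.conv
```

👆Back to the table of contents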
How to Initialize your Model by the Pretrained Models
Pretrained model loading can be easily done by giving the model path in `pretrained_model`. In the example below, the entire ASR model will be initialized by the given accuracy_best.pth model.

```yaml
mdl_root: recipe/asr/librispeech/train-clean-100/exp/{exp_name}/models
model:
    model_type: asr.ARASR
    model_conf:
        pretrained_model:
            path: !ref <mdl_root>/accuracy_best.pth
```
If the parameter names of the pretrained model don't match those of your model, you can rename them with `mapping`. In the example below, the pretrained `encoder_prenet` and `encoder` modules are loaded into `encoder.prenet` and `encoder.encoder` of the model to be trained.

```yaml
model_root: recipe/asr/librispeech/train-clean-100/exp/{exp_name}/models
model:
    model_type: asr.ARASR
    model_conf:
        pretrained_model:
            path: !ref <model_root>/accuracy_best.pth
            mapping:
                encoder_prenet: encoder.prenet
                encoder: encoder.encoder
```
Multiple pretrained models can also be given as a list. Note that if there are overlapping modules between the `mapping` arguments of different pretrained models, those modules will be initialized by the pretrained model that comes later in the list.

```yaml
model_root: recipe/asr/librispeech/train-clean-100/exp/{exp_name}/models
model:
    model_type: asr.ARASR
    model_conf:
        pretrained_model:
            - path: !ref <model_root>/accuracy_best.pth
              mapping:
                  encoder_prenet: encoder.prenet
                  encoder: encoder.encoder
            - path: !ref <model_root>/10_accuracy_average.pth
              mapping:
                  decoder_prenet: decoder.prenet
                  decoder: decoder.decoder
                  decoder_postnet: decoder.postnet
```