error_rate

Author: Heli Qi
Affiliation: NAIST
Date: 2022.07

ErrorRate

Bases: Criterion

Source code in speechain/criterion/error_rate.py
class ErrorRate(Criterion):
    """"""

    def criterion_init(self, tokenizer: Tokenizer = None, do_aver: bool = False):
        """
        Args:
            tokenizer (Tokenizer): the tokenizer attached to this criterion;
                used by __call__ when no tokenizer argument is given.
            do_aver (bool): whether to average the error rates over the batch.

        """
        self.tokenizer = tokenizer
        self.do_aver = do_aver

    def __call__(
        self,
        hypo_text: Union[torch.Tensor, List[str], str],
        real_text: Union[torch.Tensor, List[str], str],
        tokenizer: Tokenizer = None,
        do_aver: bool = False,
    ):
        """

        Args:
            hypo_text (torch.Tensor or List[str] or str): the hypothesis text
            real_text (torch.Tensor or List[str] or str): the reference text
            tokenizer (Tokenizer): the tokenizer
            do_aver (bool): whether to average the error rate over the batch

        Returns:

        """
        if tokenizer is None:
            assert self.tokenizer is not None
            tokenizer = self.tokenizer

        # make sure that hypo_text is a 2-dim tensor or a list of strings
        if isinstance(hypo_text, torch.Tensor) and hypo_text.dim() == 1:
            hypo_text = hypo_text.unsqueeze(0)
        elif isinstance(hypo_text, str):
            hypo_text = [hypo_text]
        # make sure that real_text is a 2-dim tensor or a list of strings
        if isinstance(real_text, torch.Tensor) and real_text.dim() == 1:
            real_text = real_text.unsqueeze(0)
        elif isinstance(real_text, str):
            real_text = [real_text]

        cer_dist, cer_len, wer_dist, wer_len = [], [], [], []
        for i in range(len(hypo_text)):
            # obtain the strings
            hypo_string = text_preprocess(hypo_text[i], tokenizer)
            real_string = text_preprocess(real_text[i], tokenizer)

            # calculate CER
            hypo_chars = hypo_string.replace(" ", "")
            real_chars = real_string.replace(" ", "")
            cer_dist.append(editdistance.eval(hypo_chars, real_chars))
            cer_len.append(len(real_chars))
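            # e.g. hypo "abcd" vs real "abce": distance 1 over 4 reference chars, CER = 0.25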

            # calculate WER
            # Note that split(" ") is not equivalent to split() here
            # because split(" ") will give an extra '' at the end of the list if the string ends with a " "
            # while split() doesn't
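            # e.g. "a b ".split(" ") -> ['a', 'b', ''] while "a b ".split() -> ['a', 'b']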
            hypo_words = hypo_string.split()
            real_words = real_string.split()
            wer_dist.append(editdistance.eval(hypo_words, real_words))
            wer_len.append(len(real_words))

        cer, wer = [], []
        for i in range(len(cer_dist)):
            cer.append(cer_dist[i] / cer_len[i])
            wer.append(wer_dist[i] / wer_len[i])
        if do_aver:
            cer = sum(cer) / len(cer)
            wer = sum(wer) / len(wer)

        return cer, wer
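
For orientation, here is a minimal standalone sketch of the same CER/WER computation. It assumes only the third-party editdistance package used above; the function name cer_wer is illustrative and not part of the speechain API.

import editdistance

def cer_wer(hypo: str, real: str):
    # CER: Levenshtein distance over characters (spaces removed),
    # normalized by the number of reference characters
    hypo_chars, real_chars = hypo.replace(" ", ""), real.replace(" ", "")
    cer = editdistance.eval(hypo_chars, real_chars) / len(real_chars)
    # WER: Levenshtein distance over whitespace-split word lists,
    # normalized by the number of reference words
    hypo_words, real_words = hypo.split(), real.split()
    wer = editdistance.eval(hypo_words, real_words) / len(real_words)
    return cer, wer

print(cer_wer("the cat sat", "the cat sad"))  # (1/9 = 0.111..., 1/3 = 0.333...)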

__call__(hypo_text, real_text, tokenizer=None, do_aver=False)

Parameters:

    hypo_text (Tensor or List[str] or str): the hypothesis text. Required.
    real_text (Tensor or List[str] or str): the reference text. Required.
    tokenizer (Tokenizer): the tokenizer used to turn token tensors into
        strings; falls back to the instance-level tokenizer when None. Default: None.
    do_aver (bool): whether to average the error rates over the batch. Default: False.

Returns:

    The per-sample CER and WER as two lists of floats, or two scalar floats
    averaged over the batch when do_aver is True.
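
A hedged usage sketch follows. It assumes that the Criterion base class forwards constructor keyword arguments to criterion_init, and that tok stands for any initialized speechain Tokenizer; both are assumptions not confirmed by this page.

# `tok` is hypothetical: any initialized speechain Tokenizer instance
error_rate = ErrorRate(tokenizer=tok)

# per-sample results: one CER/WER pair per hypothesis-reference pair
cer_list, wer_list = error_rate(
    hypo_text=["the cat sat", "hello word"],
    real_text=["the cat sad", "hello world"],
)

# batch-averaged results: two scalar floats
cer, wer = error_rate(
    hypo_text=["the cat sat", "hello word"],
    real_text=["the cat sad", "hello world"],
    do_aver=True,
)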


criterion_init(tokenizer=None, do_aver=False)

Parameters:

    tokenizer (Tokenizer): the tokenizer attached to this criterion; used by
        __call__ when no tokenizer argument is given. Default: None.
    do_aver (bool): whether to average the error rates over the batch. Default: False.