Skip to content



This class implements the CTC prefix scorer of Algorithm 2 in the reference paper (the links to the reference and to the official implementation were lost from this copy).

Arguments

- x (torch.Tensor): The encoder states.
- enc_lens (torch.Tensor): The actual length of each enc_states sequence.
- batch_size (int): The size of the batch.
- beam_size (int): The width of the beam.
- blank_index (int): The index of the blank token.
- eos_index (int): The index of the end-of-sequence (eos) token.
- ctc_window_size (int): Compute the CTC scores over the time frames using windowing based on attention peaks. If 0, no windowing is applied.

Source code in speechain/infer_func/
class CTCPrefixScorer:
    """This class implements the CTC prefix scorer of Algorithm 2 in
    the reference paper (citation lost from this listing -- TODO restore).
    Official implementation: (link lost from this listing -- TODO restore).

    Arguments
    ---------
    x : torch.Tensor
        The encoder states. Per the shape comment in __init__, presumably one
        row per hypothesis, i.e. (batch_size * beam_size, L, vocab_size) --
        TODO confirm. Modified in place by __init__ (masked_fill_).
    enc_lens : torch.Tensor
        The actual length of each enc_states sequence.
    batch_size : int
        The size of the batch.
    beam_size : int
        The width of beam.
    blank_index : int
        The index of the blank token.
    eos_index : int
        The index of the end-of-sequence (eos) token.
    ctc_window_size: int
        Compute the ctc scores over the time frames using windowing based on attention peaks.
        If 0, no windowing applied. NOTE(review): not referenced anywhere in
        the code visible in this listing.

    # NOTE(review): this listing is incomplete -- the closing triple quote of
    # the docstring above, the parameter list of __init__, and several closing
    # brackets / call arguments below appear to have been lost during
    # extraction. Restore from the original source before relying on it.
    def __init__(
        # NOTE(review): parameters missing here; the body reads x, enc_lens,
        # batch_size, beam_size, blank_index and eos_index -- TODO restore.
        self.blank_index = blank_index
        self.eos_index = eos_index
        self.max_enc_len = x.size(1)
        self.batch_size = batch_size
        self.beam_size = beam_size
        self.vocab_size = x.size(-1)
        self.device = x.device
        # Finite stand-in for log(0): masks padded frames here and excludes the
        # blank token from the scores returned by forward_step.
        self.minus_inf = -1e20
        # Index of the last valid frame; forward_step reads r_sum at this frame
        # to score <eos>.
        self.last_frame_index = enc_lens - 1

        # mask frames > enc_lens
        # (presumably make_mask_from_len marks valid frames True, so `~` selects
        # the padding frames -- TODO confirm)
        # NOTE(review): the closing ')' of this expression is missing from the
        # listing; .expand(-1, -1, vocab_size) needs a 3-D tensor, so an
        # .unsqueeze(-1) step may also have been dropped -- TODO confirm.
        mask = (
            ~make_mask_from_len(enc_lens, return_3d=False)
            .expand(-1, -1, self.vocab_size)
        # On masked (padding) frames every token gets log-prob minus_inf ...
        x.masked_fill_(mask, self.minus_inf)
        # ... except blank, which gets 0 (probability 1), so padded frames can
        # only be explained by blank. Both writes modify x in place.
        # NOTE(review): the closing ')' of this call is missing from the listing.
        x[:, :, self.blank_index] = x[:, :, self.blank_index].masked_fill_(
            mask[:, :, self.blank_index], 0

        # dim=0: xnb, nonblank posteriors, dim=1: xb, blank posteriors
        xnb = x.transpose(0, 1)
        # Broadcast the blank column across the vocab axis so xb matches xnb.
        xb = xnb[:, :, self.blank_index].unsqueeze(2).expand(-1, -1, self.vocab_size)

        # (2, L, batch_size * beam_size, vocab_size)
        self.x = torch.stack([xnb, xb])

        # The first index of each sentence.
        # (row offset of each utterance's beam in the flattened hypothesis axis)
        self.beam_offset = torch.arange(batch_size, device=self.device) * self.beam_size
        # The first index of each candidates.
        # NOTE(review): the closing ')' is missing from the listing; cand_offset is
        # not used anywhere in the code visible here.
        self.cand_offset = (
            torch.arange(batch_size, device=self.device) * self.vocab_size

    def forward_step(self, g, state):
        """This method performs one forward step of the prefix CTC scorer.

        Arguments
        ---------
        g : torch.Tensor
            The tensor of prefix label sequences, h = g + c. One row per
            hypothesis (batch_size * beam_size rows); g.size(1) is the current
            prefix length.
        state : tuple
            Previous ctc states: the (r, psi) pair returned by the previous
            call, or None on the first step.

        Returns
        -------
        The per-candidate score increment psi - psi_prev, and the new state
        (r, psi).

        # NOTE(review): the closing triple quote of this docstring is missing
        # from the listing.
        prefix_length = g.size(1)
        # Last token of each prefix (token 0 as a placeholder for empty prefixes).
        last_char = [gi[-1] for gi in g] if prefix_length > 0 else [0] * len(g)
        if state is None:
            # r_prev: (L, 2, batch_size * beam_size)
            # NOTE(review): the remaining torch.full arguments (fill value,
            # device) and the closing ')' are missing from the listing.
            r_prev = torch.full(
                (self.max_enc_len, 2, self.batch_size * self.beam_size),

            # r_prev[:, 0] = r^n(g), r_prev[:, 1] = r^b(g)
            # Accumulate blank posteriors at each step
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank_index], dim=0)
            psi_prev = 0.0
            # NOTE(review): an `else:` line appears to be missing here; as listed,
            # the unpacking below would also run (and fail) when state is None.
            r_prev, psi_prev = state

        # Prepare forward probs
        # r[:, 0] = r^n(h), r[:, 1] = r^b(h)
        # NOTE(review): the shape/fill-value arguments of torch.full and its
        # closing ')' are partly missing from the listing; the indexing below
        # implies r is (L, 2, batch_size * beam_size, vocab_size).
        r = torch.full(
                self.batch_size * self.beam_size,

        # (Alg.2-6)
        # Empty prefix: the non-blank forward variable at frame 0 is just the
        # non-blank posterior of each candidate at frame 0.
        if prefix_length == 0:
            r[0, 0] = self.x[0, 0]
        # (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g)
        r_sum = torch.logsumexp(r_prev, 1)
        phi = r_sum.unsqueeze(2).repeat(1, 1, self.vocab_size)

        # (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0
        # (CTC merges repeated labels unless a blank separates them, so only the
        # blank-ending paths of g may be extended by the same token)
        for i in range(self.batch_size * self.beam_size):
            phi[:, i, last_char[i]] = r_prev[:, 1, i]

        # Start, end frames for scoring (|g| < |h|).
        # Scoring based on attn peak
        # NOTE(review): no attention-based windowing is applied in the visible
        # code; the scan always runs up to max_enc_len.
        start = max(1, prefix_length)
        end = self.max_enc_len

        # Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)):
        for t in range(start, end):
            # (Alg.2-11): dim=0, p(h|cur step is nonblank) = [p(prev step=y) + phi] * p(c)
            rnb_prev = r[t - 1, 0]
            # (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank)
            rb_prev = r[t - 1, 1]
            # Row 0 pairs (rnb_prev, phi[t-1]) for r^n and row 1 pairs
            # (rnb_prev, rb_prev) for r^b; logsumexp over each pair, then add the
            # frame-t log-posteriors (non-blank for row 0, blank for row 1).
            # NOTE(review): the closing ')' of .view(...) is missing from the listing.
            r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view(
                2, 2, self.batch_size * self.beam_size, self.vocab_size
            r[t] = torch.logsumexp(r_, 1) + self.x[:, t]

        # Compute the prefix prob, psi
        psi_init = r[start - 1, 0].unsqueeze(0)
        # phi is prob at t-1 step, shift one frame and add it to the current prob p(c)
        # NOTE(review): the start of this expression is missing from the listing;
        # the remaining tokens suggest
        # torch.cat((phi[0].unsqueeze(0), phi[:-1]), dim=0) -- TODO restore.
        phix =[0].unsqueeze(0), phi[:-1]), dim=0) + self.x[0]
        # (Alg.2-13): psi = psi + phi * p(c)
        # NOTE(review): likewise garbled; presumably
        # torch.logsumexp(torch.cat((phix[start:end], psi_init), dim=0), dim=0).
        psi = torch.logsumexp([start:end], psi_init), dim=0), dim=0)

        # (Alg.2-3): if c = <eos>, psi = log(r_T^n(g) + r_T^b(g)), where T is the length of max frames
        # (T is taken per hypothesis from last_frame_index.)
        # NOTE(review): the closing ']' is missing from the listing. Also, the
        # masked_fill_ in __init__ only broadcasts if enc_lens has one entry per
        # row of x (batch_size * beam_size per the shape comment there); if so,
        # indexing last_frame_index with i // beam_size looks like it picks the
        # wrong length -- confirm against the caller.
        for i in range(self.batch_size * self.beam_size):
            psi[i, self.eos_index] = r_sum[
                self.last_frame_index[i // self.beam_size], i

        # Exclude blank probs for joint scoring
        psi[:, self.blank_index] = self.minus_inf

        # Score increment relative to the previous step, plus the new state.
        return psi - psi_prev, (r, psi)

    def permute_mem(self, memory, beam_idx, token_idx):
        """This method permutes the CTC model memory to synchronize the memory index
        with the current output.

        Arguments
        ---------
        memory : No limit
            The memory variable to be permuted: the (r, psi) state returned by
            forward_step.
        beam_idx : torch.Tensor
            The index of the previous path; used to reorder the hypothesis axis
            of r and psi.
        token_idx : torch.Tensor
            Presumably the index of the selected candidate for each surviving
            hypothesis; combined with beam_offset below -- TODO confirm its
            shape and whether it already encodes the beam position.

        Returns
        -------
        The variable of the memory being permuted: (r, psi).

        # NOTE(review): the closing triple quote of this docstring is missing
        # from the listing.
        r, psi = memory
        r, psi = r[:, :, beam_idx], psi[beam_idx]
        # The index of top-K vocab came from in (t-1) timesteps.
        # NOTE(review): the leading operand (presumably token_idx) and the closing
        # ')' (possibly followed by .view(-1)) are missing from the listing.
        # best_index only adds a per-utterance offset (beam_offset * vocab_size),
        # so it selects the intended row only if token_idx already encodes the
        # within-utterance beam position. Combined with the beam_idx reordering
        # above, confirm that the beam component is not applied twice.
        best_index = (
            + (self.beam_offset.unsqueeze(1).expand_as(token_idx) * self.vocab_size)
        # synchronize forward prob
        psi = torch.index_select(psi.view(-1), dim=0, index=best_index)
        # Broadcast each hypothesis's selected prefix score across the vocab axis.
        # NOTE(review): the closing ')' is missing from the listing.
        psi = (
            psi.view(-1, 1)
            .repeat(1, self.vocab_size)
            .view(self.batch_size * self.beam_size, self.vocab_size)

        # synchronize ctc states
        # NOTE(review): the remaining index_select arguments (presumably dim=-1,
        # index=best_index) and the closing ')' are missing from the listing.
        r = torch.index_select(
            r.view(-1, 2, self.batch_size * self.beam_size * self.vocab_size),
        r = r.view(-1, 2, self.batch_size * self.beam_size)

        return r, psi

forward_step(g, state)

This method performs one forward step of the prefix CTC scorer.


Arguments:

- g (torch.Tensor): The tensor of prefix label sequences g; each scored hypothesis is h = g + c for a candidate token c.
- state (tuple): The previous CTC states.

Source code in speechain/infer_func/
def forward_step(self, g, state):
    """This method performs one forward step of the prefix CTC scorer.

    Arguments
    ---------
    g : torch.Tensor
        The tensor of prefix label sequences, h = g + c. One row per
        hypothesis (batch_size * beam_size rows); g.size(1) is the current
        prefix length.
    state : tuple
        Previous ctc states: the (r, psi) pair returned by the previous call,
        or None on the first step.

    Returns
    -------
    The per-candidate score increment psi - psi_prev, and the new state
    (r, psi).

    # NOTE(review): this listing is incomplete -- the closing triple quote of
    # this docstring and several lines below appear to have been lost during
    # extraction. Restore from the original source before relying on it.
    prefix_length = g.size(1)
    # Last token of each prefix (token 0 as a placeholder for empty prefixes).
    last_char = [gi[-1] for gi in g] if prefix_length > 0 else [0] * len(g)
    if state is None:
        # r_prev: (L, 2, batch_size * beam_size)
        # NOTE(review): the remaining torch.full arguments (fill value, device)
        # and the closing ')' are missing from the listing.
        r_prev = torch.full(
            (self.max_enc_len, 2, self.batch_size * self.beam_size),

        # r_prev[:, 0] = r^n(g), r_prev[:, 1] = r^b(g)
        # Accumulate blank posteriors at each step
        r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank_index], dim=0)
        psi_prev = 0.0
        # NOTE(review): an `else:` line appears to be missing here; as listed,
        # the unpacking below would also run (and fail) when state is None.
        r_prev, psi_prev = state

    # Prepare forward probs
    # r[:, 0] = r^n(h), r[:, 1] = r^b(h)
    # NOTE(review): the shape/fill-value arguments of torch.full and its closing
    # ')' are partly missing from the listing; the indexing below implies r is
    # (L, 2, batch_size * beam_size, vocab_size).
    r = torch.full(
            self.batch_size * self.beam_size,

    # (Alg.2-6)
    # Empty prefix: the non-blank forward variable at frame 0 is just the
    # non-blank posterior of each candidate at frame 0.
    if prefix_length == 0:
        r[0, 0] = self.x[0, 0]
    # (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g)
    r_sum = torch.logsumexp(r_prev, 1)
    phi = r_sum.unsqueeze(2).repeat(1, 1, self.vocab_size)

    # (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0
    # (CTC merges repeated labels unless a blank separates them, so only the
    # blank-ending paths of g may be extended by the same token)
    for i in range(self.batch_size * self.beam_size):
        phi[:, i, last_char[i]] = r_prev[:, 1, i]

    # Start, end frames for scoring (|g| < |h|).
    # Scoring based on attn peak
    # NOTE(review): no attention-based windowing is applied in the visible code;
    # the scan always runs up to max_enc_len.
    start = max(1, prefix_length)
    end = self.max_enc_len

    # Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)):
    for t in range(start, end):
        # (Alg.2-11): dim=0, p(h|cur step is nonblank) = [p(prev step=y) + phi] * p(c)
        rnb_prev = r[t - 1, 0]
        # (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank)
        rb_prev = r[t - 1, 1]
        # Row 0 pairs (rnb_prev, phi[t-1]) for r^n and row 1 pairs
        # (rnb_prev, rb_prev) for r^b; logsumexp over each pair, then add the
        # frame-t log-posteriors (non-blank for row 0, blank for row 1).
        # NOTE(review): the closing ')' of .view(...) is missing from the listing.
        r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view(
            2, 2, self.batch_size * self.beam_size, self.vocab_size
        r[t] = torch.logsumexp(r_, 1) + self.x[:, t]

    # Compute the prefix prob, psi
    psi_init = r[start - 1, 0].unsqueeze(0)
    # phi is prob at t-1 step, shift one frame and add it to the current prob p(c)
    # NOTE(review): the start of this expression is missing from the listing;
    # the remaining tokens suggest
    # torch.cat((phi[0].unsqueeze(0), phi[:-1]), dim=0) -- TODO restore.
    phix =[0].unsqueeze(0), phi[:-1]), dim=0) + self.x[0]
    # (Alg.2-13): psi = psi + phi * p(c)
    # NOTE(review): likewise garbled; presumably
    # torch.logsumexp(torch.cat((phix[start:end], psi_init), dim=0), dim=0).
    psi = torch.logsumexp([start:end], psi_init), dim=0), dim=0)

    # (Alg.2-3): if c = <eos>, psi = log(r_T^n(g) + r_T^b(g)), where T is the length of max frames
    # (T is taken per hypothesis from last_frame_index.)
    # NOTE(review): the closing ']' is missing from the listing. If enc_lens
    # holds one entry per hypothesis (batch_size * beam_size), indexing
    # last_frame_index with i // beam_size looks like it picks the wrong
    # length -- confirm against the constructor's caller.
    for i in range(self.batch_size * self.beam_size):
        psi[i, self.eos_index] = r_sum[
            self.last_frame_index[i // self.beam_size], i

    # Exclude blank probs for joint scoring
    psi[:, self.blank_index] = self.minus_inf

    # Score increment relative to the previous step, plus the new state.
    return psi - psi_prev, (r, psi)

permute_mem(memory, beam_idx, token_idx)

This method permutes the CTC model memory to synchronize the memory index with the current output.


Arguments:

- memory (any type): The memory variable to be permuted.
- index (torch.Tensor): The index of the previous path. (Note: the method signature actually takes `beam_idx` and `token_idx`.)

Returns:

The permuted memory variable.

Source code in speechain/infer_func/
def permute_mem(self, memory, beam_idx, token_idx):
    """This method permutes the CTC model memory to synchronize the memory index
    with the current output.

    Arguments
    ---------
    memory : No limit
        The memory variable to be permuted: the (r, psi) state returned by
        forward_step.
    beam_idx : torch.Tensor
        The index of the previous path; used to reorder the hypothesis axis of
        r and psi.
    token_idx : torch.Tensor
        Presumably the index of the selected candidate for each surviving
        hypothesis; combined with beam_offset below -- TODO confirm its shape
        and whether it already encodes the beam position.

    Returns
    -------
    The variable of the memory being permuted: (r, psi).

    # NOTE(review): this listing is incomplete -- the closing triple quote of
    # this docstring and several closing brackets / call arguments below appear
    # to have been lost during extraction.
    r, psi = memory
    r, psi = r[:, :, beam_idx], psi[beam_idx]
    # The index of top-K vocab came from in (t-1) timesteps.
    # NOTE(review): the leading operand (presumably token_idx) and the closing
    # ')' (possibly followed by .view(-1)) are missing from the listing.
    # best_index only adds a per-utterance offset (beam_offset * vocab_size), so
    # it selects the intended row only if token_idx already encodes the
    # within-utterance beam position. Combined with the beam_idx reordering
    # above, confirm that the beam component is not applied twice.
    best_index = (
        + (self.beam_offset.unsqueeze(1).expand_as(token_idx) * self.vocab_size)
    # synchronize forward prob
    psi = torch.index_select(psi.view(-1), dim=0, index=best_index)
    # Broadcast each hypothesis's selected prefix score across the vocab axis.
    # NOTE(review): the closing ')' is missing from the listing.
    psi = (
        psi.view(-1, 1)
        .repeat(1, self.vocab_size)
        .view(self.batch_size * self.beam_size, self.vocab_size)

    # synchronize ctc states
    # NOTE(review): the remaining index_select arguments (presumably dim=-1,
    # index=best_index) and the closing ')' are missing from the listing.
    r = torch.index_select(
        r.view(-1, 2, self.batch_size * self.beam_size * self.vocab_size),
    r = r.view(-1, 2, self.batch_size * self.beam_size)

    return r, psi