pos_enc

RelPositionalEncoding

Bases: Module

Source code in speechain/module/conformer/pos_enc.py
import math

import torch


class RelPositionalEncoding(Module):
    """Relative positional encoding for the Conformer encoder.

    Generates sinusoidal encodings for both past (non-negative) and future
    (negative) relative positions, which forward() concatenates and returns
    alongside the embedded features.
    """

    def module_init(
        self,
        posenc_type: str = "mix",
        d_model: int = 512,
        emb_scale: bool = False,
        max_len: int = 5000,
        dropout: float = 0.0,
    ):

        assert posenc_type in [
            "mix",
            "sep",
        ], f"The type of PositionalEncoding layer must be either 'mix' or 'sep', but got type={posenc_type}!"
        assert (
            d_model % 2 == 0
        ), f"Cannot apply sin/cos positional encoding to the vectors with odd dimensions (got d_model={d_model:d})."

        self.posenc_type = posenc_type
        self.d_model = d_model
        self.emb_scale = emb_scale

        # positional encoding matrix
        self.max_len = max_len
        self.update_posenc(max_len)

        # positional encoding Dropout layer
        self.dropout = torch.nn.Dropout(p=dropout)

    def update_posenc(self, seq_len: int = None):
        """

        Args:
            max_len:

        """
        if seq_len is None:
            seq_len = self.max_len

        if seq_len >= self.max_len:
            # positional encoding calculation
            position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(
                torch.arange(0, self.d_model, 2, dtype=torch.float)
                * (math.log(10000.0) / self.d_model)
            )
            posenc_past, posenc_future = torch.zeros(
                seq_len, self.d_model
            ), torch.zeros(seq_len, self.d_model)

            # 'mix' positional encoding: sine and cosine values are interleaved along the feature dimension
            if self.posenc_type == "mix":
                posenc_past[:, 0::2] = torch.sin(position / div_term)
                posenc_past[:, 1::2] = torch.cos(position / div_term)

                posenc_future[:, 0::2] = torch.sin(-1 * position / div_term)
                posenc_future[:, 1::2] = torch.cos(-1 * position / div_term)
            # 'sep' positional encoding: sine values fill the first half of the feature dimension, cosine values the second half
            elif self.posenc_type == "sep":
                div_term_ext = torch.exp(
                    torch.arange(self.d_model, self.d_model * 2, 2, dtype=torch.float)
                    * (math.log(10000.0) / self.d_model)
                )
                posenc_past[:, : int(self.d_model / 2)] = torch.sin(position / div_term)
                posenc_past[:, int(self.d_model / 2) :] = torch.cos(
                    position / div_term_ext
                )

                posenc_future[:, : int(self.d_model / 2)] = torch.sin(
                    -1 * position / div_term
                )
                posenc_future[:, int(self.d_model / 2) :] = torch.cos(
                    -1 * position / div_term_ext
                )

            # a plain attribute assignment (posenc = posenc.unsqueeze(0)) would not register posenc as part of the module;
            # register_buffer() makes the encodings non-trainable members that move with the module (e.g. onto GPUs)
            self.register_buffer("posenc_past", posenc_past.unsqueeze(0))
            self.register_buffer("posenc_future", posenc_future.unsqueeze(0))

    def forward(self, emb_feat: torch.Tensor):

        # rebuild the buffered positional encodings if the input sequence reaches the buffered maximum length
        seq_len = emb_feat.size(1)
        self.update_posenc(seq_len)

        with torch.no_grad():
            posenc_past = torch.flip(self.posenc_past[:, :seq_len], (1,))
            posenc_future = self.posenc_future[:, 1:seq_len]
            # (1, 2 * seq_len - 1, embed_dim)
            posenc = torch.cat([posenc_past, posenc_future], dim=1)

        # (optional) scale the embedded feature up by sqrt(d_model)
        if self.emb_scale:
            emb_feat *= math.sqrt(self.d_model)

        # apply dropout to both prenet embedded feature and positional encoding
        return self.dropout(emb_feat), self.dropout(posenc)
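
A minimal usage sketch (the construction below assumes that speechain's Module base class forwards these keyword arguments to module_init(); the tensor shapes follow from forward() above):

import torch
from speechain.module.conformer.pos_enc import RelPositionalEncoding

# assumption: the Module base class passes these keyword arguments on to module_init()
pos_enc = RelPositionalEncoding(posenc_type="mix", d_model=256, dropout=0.1)

emb_feat = torch.randn(8, 100, 256)      # (batch, seq_len, d_model)
emb_feat, posenc = pos_enc(emb_feat)

print(emb_feat.shape)   # torch.Size([8, 100, 256])
print(posenc.shape)     # torch.Size([1, 199, 256]), i.e. 2 * seq_len - 1 relative offsets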

update_posenc(seq_len=None)

Parameters:

Name     Type  Description                                                     Default
seq_len  int   Target sequence length; the buffers are only rebuilt when it    None
               reaches or exceeds self.max_len.
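
Continuing the usage sketch above (d_model=256, default max_len=5000), the buffers are only rebuilt when the requested length reaches the buffered maximum:

print(pos_enc.posenc_past.shape)   # torch.Size([1, 5000, 256]), built once in module_init
pos_enc.update_posenc(6000)        # 6000 >= self.max_len, so the buffers are rebuilt
print(pos_enc.posenc_past.shape)   # torch.Size([1, 6000, 256])
pos_enc.update_posenc(100)         # shorter than self.max_len: nothing is recomputed
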
Source code in speechain/module/conformer/pos_enc.py
def update_posenc(self, seq_len: int = None):
    """

    Args:
        max_len:

    """
    if seq_len is None:
        seq_len = self.max_len

    if seq_len >= self.max_len:
        # positional encoding calculation
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float)
            * (math.log(10000.0) / self.d_model)
        )
        posenc_past, posenc_future = torch.zeros(
            seq_len, self.d_model
        ), torch.zeros(seq_len, self.d_model)

        # 'mix' positional encoding: sine and cosine values are interleaved along the feature dimension
        if self.posenc_type == "mix":
            posenc_past[:, 0::2] = torch.sin(position / div_term)
            posenc_past[:, 1::2] = torch.cos(position / div_term)

            posenc_future[:, 0::2] = torch.sin(-1 * position / div_term)
            posenc_future[:, 1::2] = torch.cos(-1 * position / div_term)
        # 'sep' positional encoding: sine values fill the first half of the feature dimension, cosine values the second half
        elif self.posenc_type == "sep":
            div_term_ext = torch.exp(
                torch.arange(self.d_model, self.d_model * 2, 2, dtype=torch.float)
                * (math.log(10000.0) / self.d_model)
            )
            posenc_past[:, : int(self.d_model / 2)] = torch.sin(position / div_term)
            posenc_past[:, int(self.d_model / 2) :] = torch.cos(
                position / div_term_ext
            )

            posenc_future[:, : int(self.d_model / 2)] = torch.sin(
                -1 * position / div_term
            )
            posenc_future[:, int(self.d_model / 2) :] = torch.cos(
                -1 * position / div_term_ext
            )

        # a plain attribute assignment (posenc = posenc.unsqueeze(0)) would not register posenc as part of the module;
        # register_buffer() makes the encodings non-trainable members that move with the module (e.g. onto GPUs)
        self.register_buffer("posenc_past", posenc_past.unsqueeze(0))
        self.register_buffer("posenc_future", posenc_future.unsqueeze(0))