bce_logits

Author: Sashi Novitasari Affiliation: NAIST Date: 2022.08

Author: Heli Qi Affiliation: NAIST Date: 2022.09

BCELogits

Bases: Criterion

Source code in speechain/criterion/bce_logits.py
class BCELogits(Criterion):
    """"""

    def criterion_init(self, pos_weight: float = 1.0, is_normalized: bool = True):
        """

        Args:
            pos_weight: float = 1.0
                The weight put on the stop points when calculating the stop loss.
            is_normalized: bool = True
                Controls whether sentence-length normalization is applied when calculating the stop loss.

        """
        self.bce_loss = torch.nn.BCEWithLogitsLoss(
            reduction="none", pos_weight=torch.Tensor([pos_weight])
        )
        self.is_normalized = is_normalized

    def __call__(self, pred: torch.Tensor, tgt: torch.Tensor, tgt_len: torch.Tensor):
        """

        Args:
            pred: (batch, text_maxlen)
                The model predictions for the text.
            tgt: (batch, text_maxlen)
                The target text labels.
            tgt_len: (batch,)
                The text lengths.

        Returns:
            The binary cross entropy between the predicted logits and the target labels.

        """
        batch_size, feat_maxlen = pred.size(0), pred.size(1)
        if tgt.dtype == torch.bool:
            tgt = tgt.to(dtype=torch.float32)

        # generate the padding mask for the target labels
        tgt_mask = make_mask_from_len(tgt_len).squeeze()
        if tgt_len.is_cuda:
            tgt_mask = tgt_mask.cuda(tgt_len.device)

        # BCE loss calculation
        # make sure that the pos_weight is also on GPU
        if pred.is_cuda and not self.bce_loss.pos_weight.is_cuda:
            self.bce_loss.pos_weight = self.bce_loss.pos_weight.cuda(pred.device)
        # (batch_size, feat_maxlen)
        loss = self.bce_loss(pred, tgt)
        # (batch_size, feat_maxlen) -> (batch_size * feat_maxlen)
        loss = loss.reshape(-1).masked_fill(~tgt_mask.reshape(-1), 0.0)

        # loss reshaping
        if self.is_normalized:
            # (batch_size * feat_maxlen) -> (1,)
            loss = loss.sum() / tgt_mask.sum()
        else:
            # (batch_size * feat_maxlen) -> (batch_size, feat_maxlen)
            loss = loss.reshape(batch_size, feat_maxlen)
            # (batch_size, feat_maxlen) -> (batch_size,) -> (1,)
            loss = loss.sum(dim=-1).mean()

        return loss
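
A minimal usage sketch, assuming the `Criterion` base class forwards constructor keyword arguments to `criterion_init` (the tensors below are hypothetical):

```python
import torch
from speechain.criterion.bce_logits import BCELogits

# Hypothetical batch: 2 sequences padded to length 5.
criterion = BCELogits(pos_weight=5.0, is_normalized=True)

pred = torch.randn(2, 5)        # raw (pre-sigmoid) stop logits
tgt = torch.zeros(2, 5)         # 1.0 marks the stop point
tgt[0, 3] = 1.0
tgt[1, 4] = 1.0
tgt_len = torch.tensor([4, 5])  # unpadded sequence lengths

loss = criterion(pred, tgt, tgt_len)  # scalar tensor
```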

__call__(pred, tgt, tgt_len)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pred` | `Tensor` | `(batch, text_maxlen)` The model predictions for the text. | *required* |
| `tgt` | `Tensor` | `(batch, text_maxlen)` The target text labels. | *required* |
| `tgt_len` | `Tensor` | `(batch,)` The text lengths. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The binary cross entropy between the predicted logits and the target labels. |

Source code in speechain/criterion/bce_logits.py
def __call__(self, pred: torch.Tensor, tgt: torch.Tensor, tgt_len: torch.Tensor):
    """

    Args:
        pred: (batch, text_maxlen)
            The model predictions for the text.
        tgt: (batch, text_maxlen)
            The target text labels.
        tgt_len: (batch,)
            The text lengths.

    Returns:
        The binary cross entropy between the predicted logits and the target labels.

    """
    batch_size, feat_maxlen = pred.size(0), pred.size(1)
    if tgt.dtype == torch.bool:
        tgt = tgt.to(dtype=torch.float32)

    # generate the padding mask for the target labels
    tgt_mask = make_mask_from_len(tgt_len).squeeze()
    if tgt_len.is_cuda:
        tgt_mask = tgt_mask.cuda(tgt_len.device)

    # BCE loss calculation
    # make sure that the pos_weight is also on GPU
    if pred.is_cuda and not self.bce_loss.pos_weight.is_cuda:
        self.bce_loss.pos_weight = self.bce_loss.pos_weight.cuda(pred.device)
    # (batch_size, feat_maxlen)
    loss = self.bce_loss(pred, tgt)
    # (batch_size, feat_maxlen) -> (batch_size * feat_maxlen)
    loss = loss.reshape(-1).masked_fill(~tgt_mask.reshape(-1), 0.0)

    # loss reshaping
    if self.is_normalized:
        # (batch_size * feat_maxlen) -> (1,)
        loss = loss.sum() / tgt_mask.sum()
    else:
        # (batch_size * feat_maxlen) -> (batch_size, feat_maxlen)
        loss = loss.reshape(batch_size, feat_maxlen)
        # (batch_size, feat_maxlen) -> (batch_size,) -> (1,)
        loss = loss.sum(dim=-1).mean()

    return loss
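
To make the two reduction modes concrete, here is an equivalent plain-PyTorch sketch of the masked loss; the inline mask construction stands in for `make_mask_from_len`, which is assumed to return `True` at valid (non-padded) positions:

```python
import torch
import torch.nn.functional as F

def masked_bce(pred, tgt, tgt_len, pos_weight=1.0, is_normalized=True):
    # True at valid positions, False at padding.
    mask = torch.arange(pred.size(1), device=pred.device)[None, :] < tgt_len[:, None]
    # Element-wise BCE on raw logits; positive labels are weighted by pos_weight.
    loss = F.binary_cross_entropy_with_logits(
        pred, tgt.float(), reduction="none",
        pos_weight=torch.tensor([pos_weight], device=pred.device),
    )
    loss = loss.masked_fill(~mask, 0.0)
    if is_normalized:
        # Average over every valid position in the whole batch.
        return loss.sum() / mask.sum()
    # Sum within each sentence, then average across the batch.
    return loss.sum(dim=-1).mean()
```

With `is_normalized=True`, each sentence contributes in proportion to its length; with `is_normalized=False`, every sentence contributes one per-sentence sum regardless of its length.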

criterion_init(pos_weight=1.0, is_normalized=True)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pos_weight` | `float` | The weight put on the stop points when calculating the stop loss. | `1.0` |
| `is_normalized` | `bool` | Controls whether sentence-length normalization is applied when calculating the stop loss. | `True` |
Source code in speechain/criterion/bce_logits.py
def criterion_init(self, pos_weight: float = 1.0, is_normalized: bool = True):
    """

    Args:
        pos_weight: float = 1.0
            The weight put on the stop points when calculating the stop loss.
        is_normalized: bool = True
            Controls whether sentence-length normalization is applied when calculating the stop loss.

    """
    self.bce_loss = torch.nn.BCEWithLogitsLoss(
        reduction="none", pos_weight=torch.Tensor([pos_weight])
    )
    self.is_normalized = is_normalized
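
As a quick illustration of `pos_weight`, here is a sketch using `torch.nn.BCEWithLogitsLoss` directly (the same module this class wraps): the loss at a positive (stop) position is scaled by the weight, while negative positions are unaffected.

```python
import torch

bce = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=torch.tensor([5.0]))
logits = torch.tensor([0.0, 0.0])
labels = torch.tensor([1.0, 0.0])   # one stop position, one non-stop position
print(bce(logits, labels))          # tensor([3.4657, 0.6931]) = [5 * ln 2, ln 2]
```

Up-weighting the rare stop frames this way counteracts the class imbalance between the single stop point and the many non-stop positions in each sequence.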