分享

PolyLoss:一种将分类损失函数加入泰勒展开式的损失函数

 520jefferson 2022-05-12

大家好,我是刘聪NLP。

前两天实习生给我推了一篇针对损失函数进行优化的论文,一种将分类损失函数加入泰勒展开式的损失函数-PolyLoss,全名《POLYLOSS: A POLYNOMIAL EXPANSION PERSPECTIVE OF CLASSIFICATION LOSS FUNCTIONS》。由于该篇论文是在图像任务上进行实验的,抱着试一试的心态,在NLP的AFQMC数据上进行了实验,发现是有提升的,因此分享给大家。

paper:https:///pdf/2204.12511.pdf

POLYLOSS

原理和公式推导我就不过多介绍了,想了解的同学可以自己看一下论文。该篇论文发现,其实仅增加一个多项式系数就相比于原始的Cross-Entropy Loss和Focal Loss在多种图像任务上有所提高。并且论文中提供了TF的相关代码,详细如下:

import tensorflow as tf


def cross_entropy_tf(logits, labels, class_number):
    '''TF交叉熵损失函数'''
    labels = tf.one_hot(labels, class_number)
    ce_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return ce_loss


def poly1_cross_entropy_tf(logits, labels, class_number, epsilon=1.0):
    '''poly_loss针对交叉熵损失函数优化,使用增加第一个多项式系数'''
    labels = tf.one_hot(labels, class_number)
    ce_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    poly1 = tf.reduce_sum(labels * tf.nn.softmax(logits), axis=-1)
    poly1_loss = ce_loss + epsilon * (1 - poly1)
    return poly1_loss


def focal_loss_tf(logits, labels, class_number, alpha=0.25, gamma=2.0, epsilon=1.e-7):
    '''TF focal_loss函数'''
    alpha = tf.constant(alpha, dtype=tf.float32)
    y_true = tf.one_hot(0, class_number)
    alpha = y_true * alpha + (tf.ones_like(y_true) - y_true) * (1 - alpha)
    labels = tf.cast(labels, dtype=tf.int32)
    logits = tf.cast(logits, tf.float32)
    softmax = tf.reshape(tf.nn.softmax(logits), [-1])
    labels_shift = tf.range(0, logits.shape[0]) * logits.shape[1] + labels
    prob = tf.gather(softmax, labels_shift)
    prob = tf.clip_by_value(prob, epsilon, 1. - epsilon)
    alpha_choice = tf.gather(alpha, labels)
    weight = tf.pow(tf.subtract(1., prob), gamma)
    weight = tf.multiply(alpha_choice, weight)
    fc_loss = -tf.multiply(weight, tf.log(prob))
    return fc_loss


def poly1_focal_loss_tf(logits, labels, class_number=3, alpha=0.25, gamma=2.0, epsilon=1.0):
    fc_loss = focal_loss_tf(logits, labels, class_number, alpha, gamma)
    p = tf.math.sigmoid(logits)
    labels = tf.one_hot(labels, class_number)
    poly1 = labels * p + (1 - labels) * (1 - p)
    poly1_loss = fc_loss + tf.reduce_mean(epsilon * tf.math.pow(1 - poly1, 2 + 1), axis=-1)
    return poly1_loss


if __name__ == '__main__':
    logits = [[20.51],
              [0.113]]
    labels = [12]

    print('TF loss result:')
    ce_loss = cross_entropy_tf(logits, labels, class_number=3)
    with tf.Session() as sess:
        print('tf cross_entropy:', sess.run(ce_loss))

    poly1_ce_loss = poly1_cross_entropy_tf(logits, labels, class_number=3, epsilon=1.0)
    with tf.Session() as sess:
        print('tf poly1_cross_entropy:', sess.run(poly1_ce_loss))

    fc_loss = focal_loss_tf(logits, labels, class_number=3, alpha=0.25, gamma=2.0, epsilon=1.e-7)
    with tf.Session() as sess:
        print('tf focal_loss:', sess.run(fc_loss))

    poly1_fc_loss = poly1_focal_loss_tf(logits, labels, class_number=3, alpha=0.25, gamma=2.0, epsilon=1.0)
    with tf.Session() as sess:
        print('tf poly1_focal_loss:', sess.run(poly1_fc_loss))

结果如下:

TF loss result:
tf cross_entropy: [1.9643688  0.17425454]
tf poly1_cross_entropy: [2.8241243  0.33417147]
tf focal_loss: [1.0890163  0.00334221]
tf poly1_focal_loss: [1.4649665 0.1818437]

笔者现在用Torch较多,因此提供了Torch相关的代码,

import torch
import torch.nn as nn
import torch.nn.functional as F


def poly1_cross_entropy_torch(logits, labels, class_number=3, epsilon=1.0):
    poly1 = torch.sum(F.one_hot(labels, class_number).float() * F.softmax(logits), dim=-1)
    ce_loss = F.cross_entropy(torch.tensor(logits), torch.tensor(labels), reduction='none')
    poly1_ce_loss = ce_loss + epsilon * (1 - poly1)
    return poly1_ce_loss


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes=3):
        super(FocalLoss, self).__init__()
        self.alpha = torch.zeros(num_classes)
        self.alpha[0] += alpha
        self.alpha[1:] += (1 - alpha)
        self.gamma = gamma

    def forward(self, logits, labels):
        logits = logits.view(-1, logits.size(-1))
        self.alpha = self.alpha.to(logits.device)
        logits_logsoft = F.log_softmax(logits, dim=1)
        logits_softmax = torch.exp(logits_logsoft)
        logits_softmax = logits_softmax.gather(1, labels.view(-11))
        logits_logsoft = logits_logsoft.gather(1, labels.view(-11))
        self.alpha = self.alpha.gather(0, labels.view(-1))
        loss = -torch.mul(torch.pow((1 - logits_softmax), self.gamma), logits_logsoft)
        loss = torch.mul(self.alpha, loss.t())[0, :]
        return loss


def poly1_focal_loss_torch(logits, labels, alpha=0.25, gamma=2, num_classes=3, epsilon=1.0):
    focal_loss_func = FocalLoss(alpha, gamma, num_classes)
    focal_loss = focal_loss_func(logits, labels)

    p = torch.nn.functional.sigmoid(logits)
    labels = torch.nn.functional.one_hot(labels, num_classes)
    labels = torch.tensor(labels, dtype=torch.float32)
    poly1 = labels * p + (1 - labels) * (1 - p)
    poly1_focal_loss = focal_loss + torch.mean(epsilon * torch.pow(1 - poly1, 2 + 1), dim=-1)
    return poly1_focal_loss


if __name__ == '__main__':
    logits = [[20.51],
              [0.113]]
    labels = [12]
    print('PyTorch loss result:')
    ce_loss = F.cross_entropy(torch.tensor(logits), torch.tensor(labels), reduction='none')
    print('torch cross_entropy:', ce_loss)

    poly1_ce_loss = poly1_cross_entropy_torch(torch.tensor(logits), torch.tensor(labels), class_number=3, epsilon=1.0)
    print('torch poly1_cross_entropy:', poly1_ce_loss)

    focal_loss_func = FocalLoss(alpha=0.25, gamma=2, num_classes=3)
    fc_loss = focal_loss_func(torch.tensor(logits), torch.tensor(labels))
    print('torch focal_loss:', fc_loss)

    poly1_fc_loss = poly1_focal_loss_torch(torch.tensor(logits), torch.tensor(labels), alpha=0.25, gamma=2,
                                           num_classes=3, epsilon=1.0)
    print('torch poly1_focal_loss:', poly1_fc_loss)

结果如下:

PyTorch loss result:
torch cross_entropy: tensor([1.9644, 0.1743])
torch poly1_cross_entropy: tensor([2.8241, 0.3342])
torch focal_loss: tensor([1.0890, 0.0033])
torch poly1_focal_loss: tensor([1.4650, 0.1818])

结果

笔者没有做太多的NLP实验,仅在AFQMC数据和自己公司的匹配数据上进行了实验,均有提高。在AFQMC数据上采用cross_entropy和poly1_cross_entropy进行了对比实验,在dev上的效果分别是0.7388和0.7435。

如果有时间和资源的同学,可以尝试更多,欢迎补充。

总结

反正实现又不难,代码改动又不大,为什么不试试呢?说不定有奇效~~~

这算不算水了一篇呢~~~

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多