搜索
您的当前位置:首页正文

pytorch系列12 --pytorch自定义损失函数custom loss function

来源:吉趣旅游网

本文主要内容:

1. 关于nn.Module与nn.Functional的区别:

简答的说就是, nn.Module是一个包装好的类,具体定义了一个网络层,可以维护状态和存储参数信息;而nn.Functional仅仅提供了一个计算,不会维护状态信息和存储参数。

对于activation函数,比如(relu, sigmoid等),dropout,pooling等没有训练参数,可以使用functional模块。

2. 自定义损失函数

前面讲过,只要Tensor算数操作(+, -,*, %,求导等)中,有一个Tesor
resquire_grad=True,则该操作得到的Tensor具有反向传播,自动求导的功能。

因而只要自己实现的loss使用tensor提供的math operation就可以。

所以第一种自定义loss函数的方法就是使用tensor的math operation实现loss定义

1. 继承于nn.Module

在forward中实现loss定义,注意:

  • 所有的数学操作使用tensor提供的math operation
  • 返回的tensor是0-dim的scalar
  • 有可能会用到nn.functional中的一些操作

自定义MSEloss实现:

class My_loss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x, y):
        return torch.mean(torch.pow((x - y), 2))

使用:

criterion = My_loss()

loss = criterion(outputs, targets)
2. 自定义函数

看一自定义类中,其实最终调用还是forward实现,同时nn.Module还要维护一些其他变量和状态。不如直接自定义loss函数实现:


# 2. 直接定义函数 , 不需要维护参数,梯度等信息
# 注意所有的数学操作需要使用tensor完成。
def my_mse_loss(x, y):
    return torch.mean(torch.pow

因篇幅问题不能全部显示,请点此查看更多更全内容

Top