大部分名词已经在代码得注释部分说明,所有程序代码在本文中都以呈现,只需要调用main()函数即可:
通常main函数的主框架可以不动。
def main():
args = Arg()#参数类,其中定义了程序需要的参数
#CPU设置种子用于生成随机数,以使得结果是确定的
torch.manual_seed(args.seed)
# 加载数据
train_loader, test_loader = getData(args)
# 得到卷积神经网络类
model = CNNnet()
# 定义损失函数
loss_func = nn.CrossEntropyLoss()
# 定义优化方式
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
# 正式进入训练和测试
# 其中epoch表示便利训练集的次数
for epoch in range(1,args.epoch+1):
train(args, model, train_loader, opt, loss_func, epoch)
test(args, model, test_loader, loss_func, epoch)
由于本文直接使用了自带的数据集,不需要细化处理。
在使用自定义的数据集时可以继承Dataset类,使用自己的数据集。
def getData(args:Arg):
# 数据集的预处理:把所有数据变成tensor类型,自动归一化
# 当然还可以对数据集进行其他预处理
data_tf = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor() # 自动归一化[0.0,1.0]
]
)
# 如果本地没有数据集 要先执行该代码 获取数据集
# train_data = mnist.MNIST(Arg.data_path,train=True,transform=data_tf,download=True)
train_data = mnist.MNIST(args.data_path, train=True, transform=data_tf, download=False)
test_data = mnist.MNIST(args.data_path, train=False, transform=data_tf, download=False)
# 获取迭代数据:data.DataLoader(), 把训练集和测试集依次放进去
# Dataloader返回所有的数据,分成了许多批次,一个批次有batch_size大小的数据
train_loader = data.DataLoader(train_data, batch_size=args.batchSize, shuffle=True) # shuffle:是否打乱数据
test_loader = data.DataLoader(test_data, batch_size=args.batchSize, shuffle=True) # shuffle:是否打乱数据
return train_loader, test_loader
可以再该函数中修改自己的CNN网络或者变成其他的网络,
本文主要是对框架的学习,不必对代码进行细究。
class CNNnet(torch.nn.Module):
def __init__(self):
# 复制并使用CNNNet的父类的初始化方法,即先运行nn.Module的初始化函数
super(CNNnet, self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=1,
out_channels=16,
kernel_size=3,
stride=2,
padding=1),
torch.nn.BatchNorm2d(16),
torch.nn.ReLU()
)
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(16, 32, 3, 2, 1),
torch.nn.BatchNorm2d(32),
torch.nn.ReLU()
)
self.conv3 = torch.nn.Sequential(
torch.nn.Conv2d(32, 64, 3, 2, 1),
torch.nn.BatchNorm2d(64),
torch.nn.ReLU()
)
self.conv4 = torch.nn.Sequential(
torch.nn.Conv2d(64, 64, 2, 2, 0),
torch.nn.BatchNorm2d(64),
torch.nn.ReLU()
)
self.mlp1 = torch.nn.Linear(2 * 2 * 64, 100)
self.mlp2 = torch.nn.Linear(100, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.mlp1(x.view(x.size(0), -1))
x = self.mlp2(x)
return x
每个train的执行相当于一次epoch,遍历了一次数据集,train函数框架也大体可以不变。
def train(args, model, train_loader, opt, loss_func, epoch):
#主要是针对由于model在训练时和评价时 BatchNormalization和Dropout方法模式不同,
#训练时要带上model.train(),预测时要带上model.test()
model.train()
for batch_id,(data, target) in enumerate(train_loader):
data, target = Variable(data), Variable(target) # 初始输入,[128,1,28,28],[128]
output = model(data) # 最终输出[128,10]
loss = loss_func(output, target) # 计算损失
opt.zero_grad() # 清空参与更新参数值
loss.backward() # 反向传播
opt.step() # 参数更新
if batch_id%args.interval==0:
print("Train Epoch:{}[{}/{}({:.0f}%)]\tLoss:{:.6f}".format(epoch, batch_id*len(data),len(train_loader.dataset),\
100.*batch_id/len(train_loader), loss.item()))# 隔一段时间输出一下当前情况,可自定义
损失函数和优化器在train中的使用步骤:
1、获取损失: loss = loss_func(预测值, 真实值) #针对每个batch来说的
2、清空上一步参与更新参数:opt.zero_grad()
3、误差反向传播:loss.backward()
4、更新参数:opt.step()
test的框架一般只需要自定义一下自己的评价指标即可
def test(args, model, test_loader, loss_func):
model.eval()
test_loss = 0
correct = 0
for data, target in test_loader:# 每批次每批次的输入
data, target = Variable(data), Variable(target) # 初始输入,[128,1,28,28],[128]
output = model(data) # 最终输出[128,10]
test_loss += loss_func(output, target).item()# 损失总和
isRight = torch.max(output, 1)[1].numpy() == target.numpy()
correct+=np.sum(isRight!= 0) #所有的正确率
accuracy = correct/len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: ({:.1f}%)\n'.format(test_loss, accuracy*100))
不包含全部的参数,CNN中有许多参数,因为方便适用于所有网络,本文没有将CNN的参数放进去。
class Arg:
def __init__(self):
self.batchSize = 64 # 批次大小
self.data_path = '../data/' # 路径自己设定
self.lr = 0.001 #学习率
self.epoch = 20
self.interval = 100 #多少批次的间隔后输出一下当前训练结果
# 其他
self.seed = 1 #随机种子,设置种子用于生成随机数,以使得结果是确定的
import torch
from torch.utils import data # 获取迭代数据
from torch.autograd import Variable # 获取变量
import torchvision
from torchvision.datasets import mnist # 获取数据集
import torch.nn as nn
import numpy as np
不足之处,请多指正!
因篇幅问题不能全部显示,请点此查看更多更全内容