当前位置: 首页 > news >正文

基于pytorch搭建CNN

先上代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlibmatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',  # matplotlib渲染数学字体时使用的字体,和Times New Roman差别不大"font.serif": ['SimSun'],  # 宋体'axes.unicode_minus': False  # 处理负号,即-号
}
matplotlib.rcParams.update(config)# 定义超参数
input_size = 28  # 图像的尺寸为28*28*1
num_classes = 10  # 一共有10个类别的结果
num_epochs = 3
batch_size = 64  # 一个批次训练64张图片
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)class CNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(  # 输入大小(1,28,28)nn.Conv2d(in_channels=1,  # 灰度图out_channels=16,  # 得到的特征图的个数(也是使用卷积核的个数)kernel_size=5,  # 卷积核的大小stride=1,  # 步长padding=2,  # 边缘填充,如果希望得到的特征图大小和原来一样,那么padding=(kernel_size-1)/2 if stride = 1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 池化层操作区域(2*2),输出结果为(16,14,14))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 输出(32,7,7))self.out = nn.Linear(32 * 7 * 7, 10)  # 全连接层def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)output = self.out(x)return output# 计算准确率
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1]rights = pred.eq(labels.data.view_as(pred)).sum()return rights, len(labels)# 具体实例化
net = CNN().to(device)# 损失函数
criterion = nn.CrossEntropyLoss()# 优化器
optimizer = optim.Adam(net.parameters(), lr=0.001)for epoch in range(num_epochs):train_rights = []for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)  # 数据和标签移动到 GPUnet.train()output = net(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()right = accuracy(output, target)train_rights.append(right)if batch_idx % 100 == 0:net.eval()val_rights = []with torch.no_grad():  # 测试时不计算梯度以节省内存for (data, target) in test_loader:data, target = data.to(device), target.to(device)  # 测试数据也要移动到 GPUoutput = net(data)right = accuracy(output, target)val_rights.append(right)train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch:{}[{}/{} ({:.0f}%)]\t损失:{:.6f}\t 训练集准确率:{:.2f}%\t测试集正确率:{:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100 * batch_idx / len(train_loader),loss.data,100 * train_r[0].item() / train_r[1],  # 使用 item() 获取标量100 * val_r[0].item() / val_r[1]))

详细解释

数据准备

train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

train_dataset=datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(), download=True)

  • datasets.MNIST: 这是PyTorch提供的一个数据集类,用于加载MNIST手写数字数据集。MNIST数据集包含60,000张训练图像和10,000张测试图像,每张图像都是28x28像素的灰度图像,表示0到9之间的数字。

  • root='./data': 这是数据集下载和存储的根目录。在这个例子中,数据集将被下载到当前目录下的./data文件夹中。

  • train=True: 这个参数指定加载的是训练数据集。如果设置为False,则会加载测试数据集。

  • transform=transforms.ToTensor(): 这是数据预处理的一个步骤。transforms.ToTensor()将PIL Image或numpy数组转换为PyTorch张量(Tensor),并且将像素值从[0, 255]范围归一化到[0, 1]范围。

  • download=True: 如果数据集尚未下载到指定的root目录,这个参数会自动下载数据集。

test_dataset=datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor())

  • 这行代码与第一行类似,但train=False表示加载的是MNIST的测试数据集。测试数据集包含10,000张图像,用于评估模型的性能。

  • download=True没有在这里出现,因为测试数据集通常在下载训练数据集时一同下载。

train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size, shuffle=True)

  • torch.utils.data.DataLoader: 这是PyTorch中用于构建数据加载器的类。DataLoader可以将数据集打包成mini-batch的形式,便于模型训练。

  • dataset=train_dataset: 这是我们之前创建的训练数据集对象。DataLoader将从这个数据集中抽取数据。

  • batch_size=batch_size: 这是每个mini-batch的大小。batch_size是一个预定义的变量,通常在代码的其他地方定义。

  • shuffle=True: 这个参数表示在每次迭代时是否打乱数据集。打乱数据可以避免模型学习到数据顺序的偏差,从而提高训练效果。

test_loader=torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size, shuffle=True)

  • 这行代码与第三行类似,但使用的是测试数据集。测试数据集通常不需要在每次迭代时打乱数据,但在这个例子中,shuffle=True表示在每次测试时也会打乱数据集。这在某些情况下可能并不必要,但在代码中默认了这种设置。

数据的结构

这样讲可能只能知道一些概念的东西,不知道这个数据的具体结构,下面介绍一下数据的结构:

train_dataset结构

train_dataset 是一个 torchvision.datasets.MNIST 对象,专门用于处理 MNIST 数据集。它的主要属性和方法如下:

  • 数据结构:

    • train_dataset.data: 用于存储图像数据,是一个形状为 (N, 28, 28) 的张量,其中 N 是训练样本的数量(60,000)。每个图像都是28x28的灰度图。
    • train_dataset.targets: 用于存储对应的标签,是一个一维的张量,形状为 (N,),存放着每个图像的数字标签(0-9)。
  • 常用方法:

    • __getitem__(index): 用于获取指定索引(index)的样本和标签。返回的是一个元组 (image, label),其中 image 是转换为张量后的图像,label 是相应的数字标签。
    • __len__(): 返回数据集中样本的总数,这里是60,000。

train_loader结构

train_loader 是一个 torch.utils.data.DataLoader 对象,它为 train_dataset 提供划分和迭代的方式。主要特点包括:

  • 数据结构:

    • train_loader 本身并不存储数据。它是在遍历 train_dataset 时提供数据的工具。每次迭代都将返回一个 mini-batch 的数据
    • train_loader 将生成 mini-batch 的元组,通常每次迭代返回的是 (images, labels)
      • images: 一个形状为 (batch_size, 1, 28, 28) 的张量,其中 batch_size 是指定的每个 mini-batch 的大小。
      • labels: 一个形状为 (batch_size,) 的张量,存放当前 mini-batch 中每张图像的标签。
  • 常用方法:

    • __iter__(): 使得 train_loader 可以被用在 for 循环中,每次迭代都会返回一个新的 mini-batch。
    • __len__(): 返回 DataLoader 中的总 mini-batch 数量。这通常是数据集样本数除以 batch_size。

例如,假设 batch_size 的值为 64,那么在使用 train_loader 进行迭代时,每次迭代将得到:

  • images:

    • 形状为 (64, 1, 28, 28),表示64张28x28的灰度图像。每张图像的通道数为1,因为它是灰度图。
  • labels:

    • 形状为 (64,),表示这64张图像的标签,值在0到9之间。

DataLoader生成什么?

再来介绍一下torch.utils.data.DataLoader

DataLoader 生成的是可以迭代的 mini-batch 数据。具体来说,每次迭代时,DataLoader 会返回一个 mini-batch 的数据。这个 mini-batch 通常是一个元组 (images, labels),其中:

  • images 是一个形状为 (batch_size, C, H, W) 的张量,batch_size 是每次迭代的大小,C 是图像的通道数,H 和 W 是图像的高度和宽度。
  • labels 是一个形状为 (batch_size,) 的张量,表示相应图像的标签。

如何使用enumerate?

那么我们如何遍历这个元组呢?其实python不像C/C++,遍历需要用下标遍历,比如我们像便利一个列表我们直接是for idx in list:,这个idx得到的直接是list中的值,但是有时候我们希望知道遍历的下标(用于指示此次遍历到多少遍了)。这样就可以用"enumerate"

1. enumerate 的作用

enumerate 的主要作用是在遍历一个可迭代对象时,返回一个包含索引的元组。具体来说,它会对可迭代对象中的每一个元素配上一个索引值,从 0 开始(默认),然后逐个返回索引和元素。

2. 基本语法
enumerate(iterable, start=0)

  • iterable: 任何可迭代的对象,比如列表、元组、字符串等。
  • start: 索引的起始值,默认为 0,但你可以指定其他起始值。
3. 如何使用 enumerate
示例 1:基本用法
# 一个简单的列表
fruits = ['apple', 'banana', 'mango']# 使用 enumerate 遍历列表
for index, value in enumerate(fruits):print(f"Index: {index}, Value: {value}")
输出:
Index: 0, Value: apple
Index: 1, Value: banana
Index: 2, Value: mango

在这个例子中,enumerate 为每个列表元素提供了一个索引,从 0 开始。我们通过 for 循环同时获取了索引和元素值。

示例 2:指定起始索引

你可以通过 enumerate 的第二个参数指定索引的起始值。例如,如果你想让索引从 1 开始:

fruits = ['apple', 'banana', 'mango']# 使用 enumerate 遍历列表,索引从 1 开始
for index, value in enumerate(fruits, start=1):print(f"Index: {index}, Value: {value}")

输出:

Index: 1, Value: apple
Index: 2, Value: banana
Index: 3, Value: mango

4. 在 DataLoader 中使用 enumerate

在 PyTorch 中,我们通常使用 DataLoader 来加载数据集。在训练神经网络时,我们不仅需要遍历每个 mini-batch,还需要知道当前遍历到了第几个 batch。这时,enumerate 就非常有用。

示例:在 DataLoader 中使用 enumerate

假设我们有一个 DataLoader 加载了 MNIST 数据集,我们可以使用 enumerate 来同时获取 batch 的索引和数据。

import torch
from torchvision import datasets, transforms# 定义数据集和 DataLoader
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 使用 enumerate 遍历 DataLoader
for batch_idx, (images, labels) in enumerate(train_loader):print(f"Batch Index: {batch_idx}, Images shape = {images.shape}, Labels shape = {labels.shape}")# 如果需要在第一个 batch 后停止,可以加一个条件判断if batch_idx == 0:break

输出:

Batch Index: 0, Images shape = torch.Size([64, 1, 28, 28]), Labels shape = torch.Size([64])

在这个例子中:

  • batch_idx 是当前 mini-batch 的索引。
  • images 和 labels 是当前 mini-batch 的数据和标签。
  • enumerate(train_loader) 返回的 batch_idx 是从 0 开始的批次索引,而 images 和 labels 是对应的批次数据。

CNN的定义

class CNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(  # 输入大小(1,28,28)nn.Conv2d(in_channels=1,  # 灰度图out_channels=16,  # 得到的特征图的个数(也是使用卷积核的个数)kernel_size=5,  # 卷积核的大小stride=1,  # 步长padding=2,  # 边缘填充,如果希望得到的特征图大小和原来一样,那么padding=(kernel_size-1)/2 if stride = 1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 池化层操作区域(2*2),输出结果为(16,14,14))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 输出(32,7,7))self.out = nn.Linear(32 * 7 * 7, 10)  # 全连接层def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)output = self.out(x)return output

什么是类

为了讲清楚这个结构,我们从最基本的地方讲起。

1. 类和对象

在 Python 中,类是一种用于创建对象的蓝图或模板。类定义了一组属性和方法,这些属性和方法可以被实例化后的对象所使用。

定义类
class Dog:def __init__(self, name, age):self.name = nameself.age = agedef bark(self):return "Woof!"

  • __init__ 方法: 这是类的构造函数,用于初始化对象的属性。
  • self: 表示类的实例对象本身。
创建对象
my_dog = Dog("Buddy", 3)
print(my_dog.name)  # 输出: Buddy
print(my_dog.bark())  # 输出: Woof!

2. 继承

继承是面向对象编程的一个重要特性,它允许一个类(子类)继承另一个类(父类)的属性和方法。子类可以重用父类的代码,并且可以在不修改父类的情况下添加新的功能。

定义父类
class Animal:def __init__(self, name):self.name = namedef speak(self):return "I am an animal."

定义子类
class Dog(Animal):def __init__(self, name, age):super().__init__(name)  # 调用父类的构造函数self.age = agedef speak(self):return "Woof!"

  • super(): 用于调用父类的方法。
  • super().__init__(name): 调用父类 Animal 的构造函数,初始化 name 属性。
创建子类对象
my_dog = Dog("Buddy", 3)
print(my_dog.name)  # 输出: Buddy
print(my_dog.age)  # 输出: 3
print(my_dog.speak())  # 输出: Woof!

3. super() 函数

super() 函数用于调用父类的方法。它在多重继承中特别有用,因为它可以确保正确的方法解析顺序(Method Resolution Order, MRO)。

super() 的基本用法
class Animal:def __init__(self, name):self.name = namedef speak(self):return "I am an animal."class Dog(Animal):def __init__(self, name, age):super().__init__(name)  # 调用父类的构造函数self.age = agedef speak(self):return "Woof!"

  • super().__init__(name): 调用父类 Animal 的构造函数,初始化 name 属性。
super() 的多重继承
class Animal:def speak(self):return "I am an animal."class Mammal(Animal):def speak(self):return "I am a mammal."class Dog(Mammal):def speak(self):return super().speak() + " And I bark."my_dog = Dog()
print(my_dog.speak())  # 输出: I am a mammal. And I bark.

在这个例子中,super().speak() 首先调用 Mammal 类的 speak() 方法,然后在其基础上添加新的内容。

代码整体结构解释

1. 类定义与父类初始化
class CNN(nn.Module):def __init__(self):super().__init__()

  • 类定义CNN 类继承自 nn.Module
  • 父类初始化super().__init__() 调用父类 nn.Module 的构造函数确保 nn.Module 中的所有属性和方法被正确初始化
2. 第一个卷积层 (self.conv1)
self.conv1 = nn.Sequential(  # 输入大小(1,28,28)nn.Conv2d(in_channels=1,  # 灰度图out_channels=16,  # 得到的特征图的个数(也是使用卷积核的个数)kernel_size=5,  # 卷积核的大小stride=1,  # 步长padding=2,  # 边缘填充,如果希望得到的特征图大小和原来一样,那么padding=(kernel_size-1)/2 if stride = 1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 池化层操作区域(2*2),输出结果为(16,14,14)
)

  • nn.Sequential: 是一个容器,允许我们将多个层按顺序组合起来。
  • nn.Conv2d: 2D 卷积层,用于提取图像特征。
    • in_channels=1: 输入通道数为 1(灰度图像)。
    • out_channels=16: 输出通道数为 16,即使用 16 个卷积核,得到的特征图有 16 个。
    • kernel_size=5: 卷积核大小为 5x5。
    • stride=1: 卷积操作的步长为 1。
    • padding=2: 边缘填充 2 个像素,确保输出特征图大小与输入一致。
  • nn.ReLU: 激活函数,引入非线性。
  • nn.MaxPool2d: 最大池化层,用于减少特征图的空间尺寸。
    • kernel_size=2: 池化窗口大小为 2x2,步长默认为 2,输出特征图大小为 (16, 14, 14)。
3. 第二个卷积层 (self.conv2)
self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 输出(32,7,7)
)

  • nn.Conv2d: 第二个卷积层,输入通道数为 16(来自第一个卷积层的输出),输出通道数为 32。
    • in_channels=16: 输入通道数为 16。
    • out_channels=32: 输出通道数为 32。
    • kernel_size=5: 卷积核大小为 5x5。
    • stride=1: 卷积操作的步长为 1。
    • padding=2: 边缘填充 2 个像素,确保输出特征图大小与输入一致。
  • nn.ReLU: 激活函数,引入非线性。
  • nn.MaxPool2d: 最大池化层,用于减少特征图的空间尺寸。
    • kernel_size=2: 池化窗口大小为 2x2,步长默认为 2,输出特征图大小为 (32, 7, 7)。
4. 全连接层 (self.out)
self.out = nn.Linear(32 * 7 * 7, 10)  # 全连接层

  • nn.Linear: 全连接层,将卷积层输出的特征图转换为最终的分类结果。
    • 32 * 7 * 7: 输入特征的大小,32 个特征图,每个大小为 7x7。
    • 10: 输出大小为 10,对应 10 个类别的分类任务。
5. 前向传播函数 (forward)
def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)output = self.out(x)return output

  • forward: 前向传播函数,定义了数据在神经网络中的流动路径。
    • self.conv1(x): 将输入数据 x 通过第一个卷积层 conv1
    • self.conv2(x): 将经过 conv1 处理后的数据通过第二个卷积层 conv2
    • x.view(x.size(0), -1): 将卷积层的输出展平为一维向量,x.size(0) 是批量大小,-1 表示自动计算展平后的长度。
    • self.out(x): 将展平后的数据通过全连接层 out,得到最终的输出结果。
    • return output: 返回输出结果。

各个方法的具体解释

1. nn.Conv2d 的参数

nn.Conv2d 用于创建二维卷积层,它的构造函数有以下参数:

nn.Conv2d(in_channels,      # 输入通道数out_channels,     # 输出通道数kernel_size,      # 卷积核的大小stride=1,         # 步长,默认为1padding=0,        # 填充,默认为0dilation=1,       # 空洞卷积的膨胀系数,默认为1groups=1,         # 分组卷积的数量,默认为1bias=True,        # 是否使用偏置
)

  • in_channels (int): 输入的通道数。对于灰度图像一般为 1,对于 RGB 图像一般为 3
  • out_channels (int): 输出的通道数,即卷积核的个数
  • kernel_size (int 或 tuple): 卷积核的尺寸。可以是一个整数(表示方形卷积核),也可以是一个二元组,表示长和宽,例如 (3, 5)
  • stride (int 或 tuple): 步长,默认为 1。可以是一个整数(表示在两个方向上的步长相同),也可以是一个二元组,例如 (1, 2)(表示在两个方向上不同)
  • padding (int 或 tuple): 在输入的边缘补充的零的数量(加几圈0)。也可以是一个整数或二元组,类似于 stride 的情况。
  • dilation (int 或 tuple): 控制卷积核元素之间的间隔,默认为 1。
  • groups (int): 控制卷积的分组。默认为 1 谷歌母公司的 MobileNet 等模型使用分组卷积时将其设为 2 或更多。
  • bias (bool): 是否使用偏置,默认为 True。
2. nn.ReLU 的参数
nn.ReLU(inplace=False)  # 'inplace' 是一个可选参数

  • inplace (bool): 是否进行原地操作。如果为 True,ReLU 会直接改变输入的值,可以节省内存,但不适合某些场景。默认为 False。
3. nn.MaxPool2d 的参数
nn.MaxPool2d(kernel_size,      # 池化窗口的大小stride=None,      # 步长,默认为 kernel_sizepadding=0,        # 填充大小,默认为0dilation=1,       # 空洞卷积的膨胀系数,默认为1return_indices=False,  # 是否返回池化 indicesceil_mode=False   # 是否向上取整
)

  • kernel_size (int 或 tuple): 池化窗口的大小,可以是一个整数或二元组
  • stride (int 或 tuple): 步长,默认为 kernel_size
  • padding (int 或 tuple): 在输入的边缘补充的零的数量。
  • dilation (int 或 tuple): 控制池化的间隔,默认为 1。
  • return_indices (bool): 如果为 True,返回每个池化区域的索引。默认为 False。
  • ceil_mode (bool): 若为 True,使用向上取整。在输入图像尺寸非常小的情况下可能有帮助。默认为 False。
4. nn.Linear 的参数
nn.Linear(in_features, out_features, bias=True)

  • in_features (int): 输入的特征数量,即上一层的输出维度。
  • out_features (int): 输出的特征数量,即本层的输出维度。
  • bias (bool): 是否使用偏置,默认为 True。
5. view 方法的参数
x.view(size)  # size 可以是一个整数或元组

  • size (int 或 tuple): 用于重新定义张量的形状,第一个维度通常是 batch size,有时你可以使用 -1 来自动推导某个维度的大小。例如,x.view(x.size(0), -1) 中的 -1 表示根据其他维度的大小自动计算它的值。

在PyTorch中,张量的 .size() 方法返回一个 torch.Size 对象,它包含了张量的形状信息。具体来说,.size(0) 是用来获取张量的第一维度的大小。

假设你有一个形状为 (batch_size, channels, height, width) 的四维张量 x,那么在不同维度下使用 .size() 会返回不同的值:

  • x.size(0):返回第一个维度的大小,即 batch_size
  • x.size(1):返回第二个维度的大小,即 channels
  • x.size(2):返回第三个维度的大小,即 height
  • x.size(3):返回第四个维度的大小,即 width
import torch# 创建一个四维张量
x = torch.randn(32, 3, 64, 64)  # 批量大小为 32,通道数为 3,高度和宽度均为 64# 获取各维度的大小
batch_size = x.size(0)  # 32
channels = x.size(1)    # 3
height = x.size(2)      # 64
width = x.size(3)       # 64print(f"Batch Size: {batch_size}")
print(f"Channels: {channels}")
print(f"Height: {height}")
print(f"Width: {width}")

输出:::Batch Size: 32 Channels: 3 Height: 64 Width: 64

训练过程

训练前准备

# 具体实例化
net = CNN().to(device)# 损失函数
criterion = nn.CrossEntropyLoss()# 优化器
optimizer = optim.Adam(net.parameters(), lr=0.001)

训练过程


for epoch in range(num_epochs):train_rights = []for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)  # 数据和标签移动到 GPUnet.train()output = net(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()right = accuracy(output, target)train_rights.append(right)if batch_idx % 100 == 0:net.eval()val_rights = []with torch.no_grad():  # 测试时不计算梯度以节省内存for (data, target) in test_loader:data, target = data.to(device), target.to(device)  # 测试数据也要移动到 GPUoutput = net(data)right = accuracy(output, target)val_rights.append(right)train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch:{}[{}/{} ({:.0f}%)]\t损失:{:.6f}\t 训练集准确率:{:.2f}%\t测试集正确率:{:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100 * batch_idx / len(train_loader),loss.data,100 * train_r[0].item() / train_r[1],  # 使用 item() 获取标量100 * val_r[0].item() / val_r[1]))

1. 主要循环 for epoch in range(num_epochs)

这个循环遍历所有的训练周期(epoch)。每个 epoch 代表了一次完整的训练数据集遍历。

2. 初始化 train_rights 列表

train_rights = []

train_rights 用于存储每个批次(batch)的准确率信息。准确率信息通常以元组的形式存储,比如 (正确预测的数量, 总样本数量)

3. 训练数据加载与处理

for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)

  • train_loader: 一个 PyTorch DataLoader 对象,用于加载训练数据批次。
  • batch_idx: 当前批次的索引。
  • data: 输入数据(如图像)。
  • target: 目标标签(如分类标签)。
  • data.to(device) 和 target.to(device): 将数据和标签移动到指定的设备(如 GPU)。

4. 设置模型为训练模式

net.train()

net.train() 将模型设置为训练模式。这会影响某些层的行为,如 Dropout 和 BatchNorm。在训练模式下,Dropout 会随机丢弃神经元,BatchNorm 会更新其内部统计量。

5. 前向传播

output = net(data)

使用输入数据进行前向传播,得到模型的输出结果 output

6. 计算损失

loss = criterion(output, target)

使用预定义的损失函数 criterion 计算模型输出 output 和目标标签 target 之间的损失值。

7. 清零梯度

optimizer.zero_grad()

在每次反向传播之前,需要清零梯度,以避免梯度累积。

8. 反向传播

loss.backward()

调用 backward() 方法进行反向传播,计算损失函数对模型参数的梯度。

9. 更新模型参数

optimizer.step()

使用优化器更新模型参数,根据计算出的梯度调整参数。

10. 计算训练准确率

right = accuracy(output, target)
train_rights.append(right)

计算当前批次的准确率 right,并将其添加到 train_rights 列表中。

11. 验证模型性能

if batch_idx % 100 == 0:net.eval()val_rights = []with torch.no_grad():for (data, target) in test_loader:data, target = data.to(device), target.to(device)output = net(data)right = accuracy(output, target)val_rights.append(right)

  • net.eval(): 将模型设置为评估模式。在评估模式下,Dropout 和 BatchNorm 等层的行为会有所不同。
  • torch.no_grad(): 在评估时禁用梯度计算,以节省内存和计算资源。
  • test_loader: 用于加载验证数据批次。
  • val_rights: 用于存储验证批次(batch)的准确率信息。

12. 计算并打印训练和验证的准确率

train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch:{}[{}/{} ({:.0f}%)]\t损失:{:.6f}\t 训练集准确率:{:.2f}%\t测试集正确率:{:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100 * batch_idx / len(train_loader),loss.data,100 * train_r[0].item() / train_r[1],  # 使用 item() 获取标量100 * val_r[0].item() / val_r[1]
))

  • train_r: 计算训练集的总正确预测数量和总样本数量。
  • val_r: 计算验证集的总正确预测数量和总样本数量。
  • print: 输出当前 epoch、批次数、损失值、训练集准确率和验证集准确率。

相关文章:

基于pytorch搭建CNN

先上代码 import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torchvision import datasets, transforms import matplotlib.pyplot as plt import numpy as np import pandas as pd import matplotlibmatplotlib.use(tkA…...

C#实现与Windows服务的交互与控制

在C#中,与Windows服务进行交互和控制通常涉及以下几个步骤: 创建Windows服务:首先,需要创建一个Windows服务项目。可以使用Visual Studio中的“Windows 服务 (.NET Framework)”项目模板来创建Windows服务。 配置服务控制事件&am…...

Java和Ts构造函数的区别

java中子类在使用有参构造创建对象的时候不必要必须调用父类有参构造 而js则必须用super()调用父类的有参构造,即使用不到也必须传递 Java 中的处理方式 可选择性参数: 在 Java 中,当子类使用父类的有参构造方法创建对象时,可以只传递需要的参数。如果父…...

植物健康,Spring Boot来助力

3系统分析 3.1可行性分析 通过对本植物健康系统实行的目的初步调查和分析,提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本植物健康系统采用SSM框架,JAVA作为开发语言&#…...

百度文心一言接入流程-java版

百度文心一言接入流程-java版 一、准备工作二、API接口调用-java三、百度Prompt工程参考资料: 百度文心一言:https://yiyan.baidu.com/百度千帆大模型:https://qianfan.cloud.baidu.com/百度千帆大模型文档:https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html千tokens…...

Java 11 新特性深度解析与应用实践

Java 作为一种广泛应用的编程语言,不断演进以满足开发者日益增长的需求和适应技术的发展趋势。Java 11 带来了一系列重要的新特性和改进,这些变化不仅提升了语言的性能和功能,还为开发者提供了更好的开发体验和工具。本文将深入探讨 Java 11 …...

druid 连接池监控报错 Sorry, you are not permitted to view this page.本地可以,发布正式出错

简介: druid 连接池监控报错 Sorry, you are not permitted to view this page. 使用Druid连接池的时候,遇到一个奇怪的问题,在本地(localhost)可以直接打开Druid连接池监控,在其他机器上打开会报错&#…...

[RN与H5] 加载线上H5通信失败问题记录(启动本地H5服务OK)

RT: nextjs项目 在本地启动H5服务, 本地开发都OK 发布到线上后, 效果全无, 经排查发现, 写了基本配置的js脚本在挂载时机上的差异导致 根本原因是...

electron 打包

安装及配置 安装electron包以及electron-builder打包工具 # 安装 electron cnpm install --save-dev electron # 安装打包工具 cnpm install electron-builder -D 参考的package.json文件 其中description和author为必填项目 {"name": "appfile",&qu…...

ChatGLM-6B和Prompt搭建专业领域知识问答机器人应用方案(含完整代码)

目录 ChatGLM-6B部署 领域知识数据准备 领域知识数据读取 知识相关性匹配 Prompt提示工程 领域知识问答 完整代码 本文基于ChatGLM-6B大模型和Pompt提示工程搭建医疗领域知识问答机器人为例。 ChatGLM-6B部署 首先需要部署好ChatGLM-6B,参考 ChatGLM-6B中英双…...

虚拟机配置静态IP地址(人狠话不多简单粗暴)

1.先找到以下位置: 2. 虚拟机中执行vi /etc/sysconfig/network-scripts/ifcfg-ens33 根据上图信息修改配置文件内容: 静态IP地址设置不超过255就行,我这里弄得100,没毛病。 3.修改并保存文件后,重启网络执行&#…...

Android token JJWT

在Android开发领域,JJWT(Java JWT,即Java Json Web Token)库是一个流行的工具,用于处理JSON Web Tokens(JWTs)。JWT是一种轻量级的、自包含的、基于JSON的用于双方之间安全传输信息的简洁的、UR…...

动态规划<一>初识动态规划

目录 认识动态规划 LeetCodeOJ练习 斐波那契数列模型 认识动态规划 1.动态规划是一种用于解决优化问题的算法策略。 2.它的核心原理是把一个复杂的问题分解为一系列相互关联的子问题。通过先求解子问题,并且记录这些子问题的解(通常用一个表格之类的…...

【AIGC】ChatGPT提示词Prompt精确控制指南:Scott Guthrie的建议详解与普通用户实践解析

博客主页: [小ᶻZ࿆] 本文专栏: AIGC | ChatGPT 文章目录 💯前言💯斯科特古斯里(Scott Guthrie)的建议解读人机交互设计的重要性减轻用户认知负担提高Prompt的易用性结论 💯普通用户视角的分析普通用户…...

2024年10月24日随笔

1024程序员节啊,现在已经是晚上的十点半了,我还在实验室里没走,刚把力扣的每日一题写完,好忙啊,好忙啊,好忙啊,为什么都大三了我还不能做自己的事情,今天老师开会说要给互联网加大赛…...

怎么做系统性能优化

对于软件或系统的性能优化,可以采取多种措施来提高效率和响应速度。这里为您列举一些常见的方法: 1. 代码优化:检查并优化算法复杂度,减少不必要的计算。使用更高效的数据结构和算法。 2. 数据库优化: •索引优化&…...

负载均衡:四层与七层

负载均衡建立在现在网络基础之上,提供一种廉价透明有效的方式扩展网络设备和服务器带宽、增加吞吐量、加强网络数据处理能力、提高网络的灵活性和可用性。负载均衡可分为七层负载与四层负载。 四层负载(目标地址与端口交换) 主要通过报文中…...

【Ubuntu】服务器系统重装SSHxrdpcuda

本文作者: slience_me Ubuntu系统重装操作合集 文章目录 Ubuntu系统重装操作合集1.1 系统安装:1.2 安装openssh-server更新系统包安装OpenSSH服务器检查SSH服务的状态配置防火墙以允许SSH测试SSH连接配置SSH(可选) 1.3 安装远程连…...

ChatGPT的模型训练入门级使用教程

ChatGPT 是由 OpenAI 开发的一种自然语言生成模型,基于 Transformer 架构的深度学习技术,能够流畅地进行对话并生成有意义的文本内容。它被广泛应用于聊天机器人、客户服务、内容创作、编程助手等多个领域。很多人对如何训练一个类似 ChatGPT 的语言模型…...

【OS】2.1.2 进程的状态与转换_进程的组织

✨ Blog’s 主页: 白乐天_ξ( ✿>◡❛) 🌈 个人Motto:他强任他强,清风拂山冈! 🔥 所属专栏:C深入学习笔记 💫 欢迎来到我的学习笔记! 一、进程的状态 1.1.创建态 ……的…...

和为 n 的完全平方数的最少数量

给你一个整数 n ,返回 和为 n 的完全平方数的最少数量 。 完全平方数 是一个整数,其值等于另一个整数的平方;换句话说,其值等于一个整数自乘的积。例如,1、4、9 和 16 都是完全平方数,而 3 和 11 不是。 示…...

Hallo2 长视频和高分辨率的音频驱动的肖像图像动画 (数字人技术)

HALLO2: LONG-DURATION AND HIGH-RESOLUTION AUDIO-DRIVEN PORTRAIT IMAGE ANIMATION 论文:https://arxiv.org/abs/2410.07718 代码:https://github.com/fudan-generative-vision/hallo2 模型:https://huggingface.co/fudan-generative-ai/h…...

如何在Debian 8上使用Let‘s Encrypt保护Apache

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 简介 本教程将向您展示如何在运行 Apache 作为 Web 服务器的 Debian 8 服务器上设置来自 Let’s Encrypt 的 TLS/SSL 证书。我们还将介…...

百科知识|选购指南

百科知识||选购指南 百科知识选购指南茶叶分类茶叶的味道来源茶叶制作步骤名茶其他一些茶叶的知识 百科知识 选购指南 茶叶 分类 茶叶种类: 六大茶类完美分析介绍!茶友推荐收藏 (aboxtik.com) 1.绿茶(发酵率0%) 2.白茶(发酵率…...

Go 语言基础教程:4.常量的使用

在这篇教程中,我们将通过一个简单的 Go 语言程序来学习常量的声明和使用。以下是我们要分析的代码: package mainimport ("fmt""math" )const s string "constant"func main() {fmt.Println(s)const n 500000000const …...

centos服务器重启后,jar包自启动

第一种方法: systemctl服务自启动 在/usr/lib/systemd/system目录下,创建service:start_jar.servie [Unit] DescriptionYour Java Application as a Service Afternetwork.target[Service] Userroot Typesimple ExecStart/usr/bin/java -j…...

华为云实战杂记

配置nginx服务器 首先我们拿到一台服务器时,并不知道系统是否存在Nginx我们可以在Linux命令行执行如下命令查看 find / -name nginx* find / -name nginx* 查找所有名字以nginx开头的文件或者目录,我们看看系统里面都有哪些文件先,这样可以快…...

Lesson10---list

Lesson10—list 第10章 c的list的使用和实现 文章目录 Lesson10---list前言一、list的初始化二、list的遍历1.迭代器2.范围for 三、list常用的内置函数1.sort(慎用)2.unique3.reverse4.merge5.splice 四、模拟实现1.基本框架2.构造函数3.push_back4. 遍…...

ASP.NET Core 8.0 中使用 Hangfire 调度 API

在这篇博文中,我们将引导您完成将 Hangfire 集成到 ASP.NET Core NET Core 项目中以安排 API 每天运行的步骤。Hangfire 是一个功能强大的库,可简化 .NET 应用程序中的后台作业处理,使其成为调度任务的绝佳选择。继续阅读以了解如何设置 Hang…...

查看linux的版本

在 Linux 系统中,有多种方法可以查看当前系统的版本信息。以下是一些常用的方法: 1. 使用 uname 命令 uname 命令可以显示系统的内核版本和其他相关信息。 uname -a这个命令会输出类似如下的信息: Linux hostname 5.4.0-88-generic #99-U…...

工业设计好找工作吗/湖南seo优化首选

FFmpeg for XP(x86) 2016-03-23 static 静态编译适用于32位XP系统,能加的扩展都加了,结果文件大小非常大. 最新版加了不少视频和音频滤镜. ffmpeg.20160323.for.XP.x86.static.7z./configure --enable-static --disable-shared --enable-gpl --enable-version3 --enable-nonfre…...

轻淘客网站建设/百度识图在线识别

转自: http://blog.csdn.net/liu1347508335/article/details/51097761 Objective-C中的音乐播放大多用AVAudioPlayer,它有很多优点: (1)可以播放任意长度音乐; (2)可以循环播放; …...

建网站域名注册后需要/百度做免费推广的步骤

搞FPGA,SRAM是必过的一关,毕竟芯片最核心的就是运算与存储,本篇文章属于转载,详细介绍了标准工艺下的SRAM工作原理,一般工艺库或者实例化的SRAM使用的就是这种标准SRAM,有地址译码器,地址线&…...

广州app开发公司排名/怎么制作seo搜索优化

pytest自动化测试框架之入门知识pytest自动化测试框架总结1.入门安装 pytest1.1资源获取1.2运行pytest1.3运行单个案例1.4使用命令行pytest自动化测试框架总结 1.入门 pytest 是一个使构建简单和可伸缩的测试变得容易的框架。测试具有表达性和可读性,不需要样板代…...

定制型网站开发/seo网站优化培训怎么样

文章目录 1 结论2 详解2.1 日期间隔 NumToYMInterval()2.2 时间间隔 NumToDSInterval() 3 扩展 1 结论 日期间隔函数 NumToYMInterval(),间隔周期 YEAR 年,MONTH 月 时间间隔函数 NumToDSInterval(),间隔周期 DAY 天,HOUR 小时&a…...

海尔网站建设策划书/正规seo一般多少钱

1.对c语言的看法在上大学之前,我对这个专业仅仅的认知是学电脑的,对编程来说更是一无所知,而我选择计算机专业完全是因为我从小就喜欢玩电脑,仅此而已。记得小时候还不会拼音和英文的时候,我玩的第一个游戏就是侠盗飞车…...