深度学习笔记_4、CNN卷积神经网络+全连接神经网络解决MNIST数据
1、首先,导入所需的库和模块,包括NumPy、PyTorch、MNIST数据集、数据处理工具、模型层、优化器、损失函数、混淆矩阵、绘图工具以及数据处理工具。
import numpy as np
import torch
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import pandas as pd
2、设置超参数,包括训练批次大小、测试批次大小、学习率和训练周期数。
# 设置超参数
train_batch_size = 64
test_batch_size = 64
learning_rate = 0.001
num_epochs = 10
3、创建数据转换管道,将图像数据转换为张量并进行标准化。
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])
4、下载和预处理MNIST数据集,分为训练集和测试集。
# 下载和预处理数据集
train_dataset = mnist.MNIST('data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('data', train=False, transform=transform)
5、创建用于训练和测试的数据加载器,以便有效地加载数据。
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
6、定义了一个简单的CNN模型,包括两个卷积层和两个全连接层。
# 定义CNN模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=5)self.conv2 = nn.Conv2d(32, 64, kernel_size=5)self.fc1 = nn.Linear(1024, 256)self.fc2 = nn.Linear(256, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2(x), 2))x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)
7、初始化模型、优化器和损失函数。
# 初始化模型、优化器和损失函数
model = CNN()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
8、准备用于记录训练和测试过程中损失和准确率的列表。
# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []
9、进入训练循环,遍历每个训练周期。在每个训练周期内,进入训练模式,遍历训练数据批次,计算损失、反向传播并更新模型参数,同时记录训练损失和准确率。
for epoch in range(num_epochs):model.train()train_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()# 计算训练准确率_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 计算平均训练损失和训练准确率train_loss /= len(train_loader)train_accuracy = 100. * correct / totaltrain_losses.append(train_loss)train_accuracies.append(train_accuracy) # 记录训练准确率# 测试模型model.eval()test_loss = 0.0correct = 0all_labels = []all_preds = []with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_labels.extend(target.numpy())all_preds.extend(pred.numpy())
10、在每个训练周期结束后,进入测试模式,遍历测试数据批次,计算测试损失和准确率,同时记录它们。打印每个周期的训练和测试损失以及准确率。
# 计算平均测试损失和测试准确率test_loss /= len(test_loader)test_accuracy = 100. * correct / len(test_loader.dataset)test_losses.append(test_loss)test_accuracies.append(test_accuracy)print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
11、losses、acces、eval_losses、eval_acces保存到TXT文件
# 保存训练结果
data = np.column_stack((train_losses,test_losses,train_accuracies, test_accuracies))
np.savetxt("results.txt", data)
12、绘制Loss、ACC图像
# 绘制Loss曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()# 绘制Accuracy曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_accuracies, label='Train Accuracy', color='red') # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()
13、绘制混淆矩阵图像
# 计算混淆矩阵
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()
相关文章:
深度学习笔记_4、CNN卷积神经网络+全连接神经网络解决MNIST数据
1、首先,导入所需的库和模块,包括NumPy、PyTorch、MNIST数据集、数据处理工具、模型层、优化器、损失函数、混淆矩阵、绘图工具以及数据处理工具。 import numpy as np import torch from torchvision.datasets import mnist import torchvision.transf…...
高效的开发流程搭建
目录 1. 搭建 AI codebase 环境kaggle的服务器1. 搭建 AI codebase 环境 python 、torch 以及 cuda版本,对AI的影响最大。不同的版本,可能最终计算出的结果会有区别。 硬盘:PCIE转SSD的卡槽,, GPU: 软件源: Anaconda: 一定要放到固态硬盘上。 VS code 的 debug功能…...
浅谈OV SSL 证书的优势
随着网络威胁日益增多,保护网站和用户安全已成为每个企业和组织的重要任务。在众多SSL证书类型中,OV(Organization Validation)证书以其独特的优势备受关注。让我们深入探究OV证书的优势所在,为网站安全搭建坚实的防线…...
一篇博客学会系列(3) —— 对动态内存管理的深度讲解以及经典笔试题的深度解析
目录 动态内存管理 1、为什么存在动态内存管理 2、动态内存函数的介绍 2.1、malloc和free 2.2、calloc 2.3、realloc 3、常见的动态内存错误 3.1、对NULL指针的解引用操作 3.2、对动态开辟空间的越界访问 3.3、对非动态开辟内存使用free释放 3.4、使用free释放一块动态…...
【C++ techniques】虚化构造函数、虚化非成员函数
constructor的虚化 virtual function:完成“因类型而异”的行为;constructor:明确类型时构造函数;virtual constructor:视其获得的输入,可产生不同的类型对象。 //假如写一个软件,用来处理时事…...
蓝牙核心规范(V5.4)11.6-LE Audio 笔记之初识音频位置和通道分配
专栏汇总网址:蓝牙篇之蓝牙核心规范学习笔记(V5.4)汇总_蓝牙核心规范中文版_心跳包的博客-CSDN博客 爬虫网站无德,任何非CSDN看到的这篇文章都是盗版网站,你也看不全。认准原始网址。!!! 音频位置 在以前的每个蓝牙音频规范中,只有一个蓝牙LE音频源和一个蓝牙LE音频接…...
mysql双主+双从集群连接模式
架构图: 详细内容参考: 结果展示: 178.119.30.14(主) 178.119.30.15(主) 178.119.30.16(从) 178.119.30.17(从)...
嵌入式中如何用C语言操作sqlite3(07)
sqlite3编程接口非常多,对于初学者来说,我们暂时只需要掌握常用的几个函数,其他函数自然就知道如何使用了。 数据库 本篇假设数据库为my.db,有数据表student。 nonamescore4嵌入式开发爱好者89.0 创建表格语句如下: CREATE T…...
RandomForestClassifier 与 GradientBoostingClassifier 的区别
RandomForestClassifier(随机森林分类器)和GradientBoostingClassifier(梯度提升分类器)是两种常用的集成学习方法,它们之间的区别分以下几点。 1、基础算法 RandomForestClassifier:随机森林分类器是基于…...
计组——I/O方式
一、程序查询方式 CPU不断轮询检查I/O控制器中“状态寄存器”,检测到状态为“已完成”之后,再从数据寄存器取出输入数据。 过程: 1.CPU执行初始化程序,并预置传送参数;设置计数器、设置数据首地址。 2. 向I/O接口发…...
jsbridge实战2:Swift和h5的jsbridge通信
[[toc]] demo1: 文本通信 h5 -> app 思路: h5 全局属性上挂一个变量app 接收这个变量的内容关键API: navigation代理 navigationAction.request.url?.absoluteString // 这个变量挂载在 request 的 url 上 ,在浏览器实际无法运行,因…...
集合原理简记
HashMap 无论在构造函数是否指定数组长度,进行的都是延迟初始化 构造函数作用: 阈值:threshold,每次<<1 ,数组长度 负载因子 无参构造:设置默认的负载因子 有参:可以指定初始容量或…...
机器学习的超参数 、训练集、归纳偏好
一、介绍 超参数(Hyperparameters)和验证集(Validation Set)是机器学习中重要的概念,用于调整模型和评估其性能。 超参数: 超参数是在机器学习模型训练过程中需要手动设置的参数,而不是从数据…...
Leetcode1071. 字符串的最大公因子(三种方法,带详细解析)
Leetcode1071. 字符串的最大公因子 对于字符串 s 和 t,只有在 s t … t(t 自身连接 1 次或多次)时,我们才认定 “t 能除尽 s”。 给定两个字符串 str1 和 str2 。返回 最长字符串 x,要求满足 x 能除尽 str1 且 x 能…...
如何像人类一样写HTML之图像标签,超链接标签与多媒体标签
文章目录 前言一、图像标签1.1 什么是图像标签?2.2 如何使用图像标签? 二、超链接标签2.1 什么是超链接标签?2.2 如何使用超链接标签? 三、多媒体标签3.1 什么是多媒体标签?3.2 如何使用多媒体audio标签?3.…...
1300*C. Rumor(并查集贪心)
解析: 并查集,求每个集合的最小费用。 每次合并集合的时候,根节点保存当前集合最小的费用。 #include<bits/stdc.h> using namespace std; #define int long long const int N1e55; int n,m,a[N],p[N],cnt[N]; int find(int x){retur…...
python实用小代码(数据分析向)
一、python实现31省市名称替换 首先定义了一个字典province_dict,其中包含每个省市的无缩写名称与其标准名称之间的映射。然后,我们使用map()函数将每个省市名称从无缩写名称转换为标准名称,并将结果存储在新列省市标准名称中。 province_di…...
【oncmdmsg 鼠标】2023/8/19 上午9:50:14
2023/8/19 上午9:50:14 oncmdmsg 鼠标 2023/8/19 上午9:50:22 抱歉,您提到的 “oncmdmsg” 和 “鼠标” 是什么?请提供更多上下文信息,我将尽力帮助您。 2023/8/19 上午9:51:43 OnCmdMsg 2023/8/19 上午9:52:21 “OnCmdMsg” 是一个在 MFC (Microsoft Foundation Cla…...
插入排序:简单而有效的排序方法
在计算机科学中,排序算法是一个重要且常见的主题,它们用于对数据进行有序排列。插入排序(Insertion Sort)是其中一个简单但有效的排序算法。本文将详细解释插入排序的原理和步骤,并提供Java语言的实现示例。 插入排序的…...
OpenGL之光照贴图
我们需要拓展之前的系统,引入漫反射和镜面光贴图(Map)。这允许我们对物体的漫反射分量和镜面光分量有着更精确的控制。 漫反射贴图 我们希望通过某种方式对物体的每个片段单独设置漫反射颜色。我们仅仅是对同样的原理使用了不同的名字:其实都是使用一张覆盖物体的图像,让我…...
隐私交易成新刚需,Unijoin 凭什么优势杀出重围?
随着区块链技术的普及和发展,全球加密货币用户在持续增长,根据火币研究院公布的数据,2022年全球加密用户已达到 3.2亿人,目前全球人口总数超过了 80亿,加密货币用户渗透率已达到了 4%。 尤其是在 2020 年开启的 DeFi 牛…...
小谈设计模式(12)—迪米特法则
小谈设计模式(12)—迪米特法则 专栏介绍专栏地址专栏介绍 迪米特法则核心思想这里的“朋友”指当前对象本身以参数形式传入当前对象的对象当前对象的成员变量直接引用的对象目标 Java程序实现程序分析 总结 专栏介绍 专栏地址 link 专栏介绍 主要对目…...
Foxit PDF
Foxit PDF 福昕PDF 软件,可以很好的编辑PDF文档。 调整PDF页面大小 PDF文档中,一个页面大,一个页面小 面对这种情况,打开Foxit PDF 右键单击需要调整的页面,然后选择"调整页面大小". 可以选择…...
《Python趣味工具》——ppt的操作(刷题版)
前面我们对PPT进行了一定的操作,并将其中的文字提取到了word文档中。现在就让我们来刷几道题巩固巩固吧! 文章目录 1. 查看PPT(上)2. 查看PPT(中)3. 查看PPT(下)4. PPT的页码5. 大学…...
实战型开发--3/3,clean code
编程的纯粹 hmmm,一开始在这个环节想聊一些具体的点,其实也就是《clean code》这本书中的点,但这个就还是更流于表面; 因为编码的过程,就更接近于运动员打球,艺术家绘画,棋手下棋的过程&#x…...
家用无线路由器如何用网线桥接解决有些房间无线信号覆盖不好的问题(低成本)
环境 光猫ZXHN F677V9 水星MW325R 无线百兆路由器 100M宽带,2.4G无线网络 苹果手机 安卓平板电脑 三室一厅94平 问题描述 家用无线路由器如何用网线桥接解决有些房间无线信号不好问题低成本解决,无线覆盖和漫游 主路由器用的运营商的光猫自带无…...
【Golang】网络编程
网络编程 网络模型介绍 OSI七层网络模型 在软件开发中我们使用最多的是上图中将互联网划分为五个分层的模型: 物理层数据链路层网络层传输层应用层 物理层 我们的电脑要与外界互联网通信,需要先把电脑连接网络,我们可以用双绞线、光纤、…...
使用策略模式优化多重if/else
一、为什么需要策略模式? 作为前端程序员,我们经常会遇到这样的场景,例如 进入一个营销活动页面,会根据后端下发的不同 type ,前端页面展示不同的弹窗。 async getMainData() {try {const res await activityQuery()…...
逆强化学习
1.逆强化学习的理论框架 1.teacher的行为被定义成best 2.学习的网络有两个,actor和reward 3.每次迭代中通过比较actor与teacher的行为来更新reward function,基于新的reward function来更新actor使得actor获得的reward最大。 loss的设计相当于一个排序问…...
postgresql新特性之Merge
postgresql新特性之Merge 创建测试表测试案例 创建测试表 create table cps.public.test(id integer primary key,balance numeric,status varchar(1));测试案例 官网介绍 merge into test t using ( select 1 id,0 balance,Y status) s on(t.id s.id) -- 当匹配上了,statu…...
wordpress 修改端口/免费域名申请网站
1、文档如何存储1.1 分片与路由当索引一个文档的时候,文档会被存储到一个主分片中。Elasticsearch 如何知道一个文档应该存放到哪个分片中呢?当我们创建文档时,它如何决定这个文档应当被存储在分片 1 还是分片 2 中呢?首先这肯定不…...
网站开发源代码/新的网站怎么推广
欢迎学习交流!!! 持续更新中… 文章目录预加载什么是预加载为什么要用预加载预加载的实现懒加载什么是懒加载为什么要用懒加载懒加载的实现懒加载优化预加载和懒加载的比较预加载 什么是预加载 资源预加载是另一个性能优化技术,我…...
html5 网站布局应用教程/百度手机助手app
在这个新闻客户端,我们可以看到有一个轮播页面,在这个项目中,用Handler和一个定时器来做更容易一些, 我们定义一个Handler: private Handler mHandler; 定时器的代码如下: // 自动轮播条显示if (mHandle…...
哪个网站是做包装材料珍珠棉包管/免费直链平台
问题描述:本人的项目是用Maven管理,而且用到了servlet3.0的技术,但是项目中用到servlet3.0的地方,总提示找不到类中的方法。很奇怪,在网上找到好多解决办法,综合一下终于解决了。现将经验分享给大家。 前提…...
wordpress取消自动分页/seo点石论坛
Map 接口实现类的特点 Map 与 Collection 并列存在。用于保存具有映射关系的数据:Key-ValueMap 中的 key 和 value 可以是任何引用类型的数据,会封装到 HashMap$Node 对象中Map 中的 key 不允许重复,原因和 HashSet 一样,前面分析…...
自己做的网站打开显示很慢/百度竞价推广登录入口
文章目录MySQL行锁和表锁表锁行锁注意事项间隙锁总结MySQL行锁和表锁 MySQL常用引擎有MyISAM和InnoDB,而InnoDB是mysql默认的引擎。MyISAM不支持行锁,而InnoDB支持行锁和表锁。 MyISAM在执行select时,会自动给所有涉及的表加读锁,…...