卷积神经网络实现咖啡豆分类 - P7
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊 | 接辅导、项目定制
- 🚀 文章来源:K同学的学习圈子
目录
- 环境
- 步骤
- 环境设置
- 包引用
- 全局设备对象
- 数据准备
- 查看图像的信息
- 制作数据集
- 模型设计
- 手动搭建的vgg16网络
- 精简后的咖啡豆识别网络
- 模型训练
- 编写训练函数
- 编写测试函数
- 开始训练
- 展示训练过程
- 模型效果展示
- 总结与心得体会
环境
- 系统: Linux
- 语言: Python3.8.10
- 深度学习框架: Pytorch2.0.0+cu118
- 显卡:A5000 24G
步骤
环境设置
包引用
import torch
import torch.nn as nn # 网络
import torch.optim as optim # 优化器
from torch.utils.data import DataLoader, random_split # 数据集划分
from torchvision import datasets, transforms # 数据集加载,转换import pathlib, random, copy # 文件夹遍历,实现模型深拷贝
from PIL import Image # python自带的图像类
import matplotlib.pyplot as plt # 图表
import numpy as np
from torchinfo import summary # 打印模型参数
全局设备对象
方便将模型和数据统一拷贝到目标设备中
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
数据准备
查看图像的信息
data_path = 'coffee_data'
data_lib = pathlib.Path(data_path)
coffee_images = list(data_lib.glob('*/*'))# 打印5张图像的信息
for _ in range(5):image = random.choice(coffee_images)print(np.array(Image.open(str(image))).shape)
通过打印的信息,可以看出图像的尺寸都是224x224的,这是一个CV经常使用的图像大小,所以后面我就不使用Resize来缩放图像了。
# 打印20张图像粗略的看一下
plt.figure(figsize=(20, 4))
for i in range(20):plt.subplot(2, 10, i+1)plt.axis('off')image = random.choice(coffee_images) # 随机选出一个图像plt.title(image.parts[-2]) # 通过glob对象取出它的文件夹名称,也就是分类名plt.imshow(Image.open(str(image))) # 展示
通过展示,对数据集内的图像有个大概的了解
制作数据集
先编写数据的预处理过程,用来使用pytorch的api加载文件夹中的图像
transform = transforms.Compose([transforms.ToTensor(), # 先把图像转成张量transforms.Normalize( # 对像素值做归一化,将数据范围弄到-1,1mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],),
])
加载文件夹
dataset = datasets.ImageFolder(data_path, transform=transform)
从数据中取所有的分类名
class_names = [k for k in dataset.class_to_idx]
print(class_names)
将数据集划分出训练集和验证集
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_sizetrain_dataset, test_dataset = random_split(dataset, [train_size, test_size])
将数据集按划分成批次,以便使用小批量梯度下降
batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
模型设计
在一开始时,直接手动创建了Vgg-16网络,发现少数几个迭代后模型就收敛了,于是开始精简模型。
手动搭建的vgg16网络
class Vgg16(nn.Module):def __init__(self, num_classes):super().__init__()self.block1 = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64, 64, 3, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),)self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2),)self.block3 = nn.Sequential(nn.Conv2d(128, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(2),)self.block4 = nn.Sequential(nn.Conv2d(256, 512, 3, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 512, 3, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 512, 3, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2),)self.block5 = nn.Sequential(nn.Conv2d(512, 512, 3, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 512, 3, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 512, 3, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2),)self.pool = nn.AdaptiveAvgPool2d(7)self.classifier = nn.Sequential(nn.Linear(7*7*512, 4096),nn.Dropout(0.5),nn.ReLU(),nn.Linear(4096, 4096),nn.Dropout(0.5),nn.ReLU(),nn.Linear(4096, num_classes),)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.pool(x)x = x.view(x.size(0),-1)x = self.classifier(x)return x
vgg = Vgg16(len(class_names)).to(device)
summary(vgg, input_size=(32, 3, 224, 224))
通过模型结构的打印可以发现,VGG-16网络共有134285380个可训练参数(我加了BatchNorm,和官方的比会稍微多出一些),参数量非常巨大,对于咖啡豆识别这种小场景,这么多可训练参数肯定浪费,于是对原始的VGG-16网络结构进行精简。
精简后的咖啡豆识别网络
class Network(nn.Module):def __init__(self, num_classes):super().__init__()self.block1 = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64, 64, 3, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),)self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2),)self.block3 = nn.Sequential(nn.Conv2d(128, 64, 3, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64, 64, 3, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),)self.pool = nn.AdaptiveAvgPool2d(7),self.classifier = nn.Sequential(nn.Linear(7*7*64, 64),nn.Dropout(0.4),nn.ReLU(),nn.Linear(64, num_classes))def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.pool(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x
model = Network(len(class_names)).to(device)
summary(model, input_size=(32, 3, 224, 224))
可以看到精简后的网络模型参数量还不到原来的1/10,但是其在测试集上的正确率依然能够达到100%!
模型训练
编写训练函数
def train(train_loader, model, loss_fn, optimizer):size = len(train_loader.dataset)num_batches = len(train_loader)train_loss, train_acc = 0, 0for x, y in train_loader:x, y = x.to(device), y.to(device)pred = model(x)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss /= num_batchestrain_acc /= sizereturn train_loss, train_acc
编写测试函数
def test(test_loader, model, loss_fn):size = len(test_loader.dataset)num_batches = len(test_loader)test_loss, test_acc = 0, 0for x, y in test_loader:x, y = x.to(device), y.to(device)pred = model(x)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchestest_acc /= sizereturn test_loss, test_acc
开始训练
首先定义损失函数,优化器设置学习率,这里我们再弄一个学习率的衰减,再加上总的迭代次数,最佳模型的保存位置
epochs = 30
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.92**(epoch//2))
best_model_path = 'best_coffee_model.pth'
然后编写训练+测试的循环,并记录训练过程的数据
best_acc = 0
train_loss, train_acc = [], []
test_loss, test_acc = [], []
for epoch in epochs:model.train()epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)scheduler.step()model.eval()with torch.no_grad();epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)train_loss.append(epoch_train_loss)train_acc.append(epoch_train_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)lr = optimizer.state_dict()['param_groups'][0]['lr']if best_acc < epoch_test_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)print(f"Epoch: {epoch+1}, Lr:{lr}, TrainAcc: {epoch_train_acc*100:.1f}, TrainLoss: {epoch_train_loss:.3f}, TestAcc: {epoch_test_acc*100:.1f}, TestLoss: {epoch_test_loss:.3f}")print(f"Saving Best Model with Accuracy: {best_acc*100:.1f} to {best_model_path}")
torch.save(best_model.state_dict(), best_model_path)
print('Done')
可以看出,模型在测试集上的正确率最高达到了100%
展示训练过程
epoch_ranges = range(epochs)
plt.figure(figsize=(20,6))
plt.subplot(121)
plt.plot(epoch_ranges, train_loss, label='train loss')
plt.plot(epoch_ranges, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.figure(figsize=(20,6))
plt.subplot(122)
plt.plot(epoch_ranges, train_acc, label='train accuracy')
plt.plot(epoch_ranges, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')
模型效果展示
model.load_state_dict(torch.load(best_model_path))
model.to(device)
model.eval()plt.figure(figsize=(20,4))
for i in range(20):plt.subplot(2, 10, i+1)plt.axis('off')image = random.choice(coffee_images)input = transform(Image.open(str(image))).to(device).unsqueeze(0)pred = model(input)plt.title(f'T:{image.parts[-2]}, P:{class_names[pred.argmax()]}')plt.imshow(Image.open(str(image)))
通过结果可以看出,确实是所有的咖啡豆都正确的识别了。
总结与心得体会
- 因为目前网络还是很快就收敛到一个很高的水平,所以应该还有很大的精简的空间,但是可能会稍微牺牲一些正确率。
- 模型的选取要根据实际任务来确定,像咖啡豆种类识别这种任务,使用VGG-16太浪费了。
- 在精简的过程中,没有感觉到训练速度有明显的变化 ,说明参数量和训练速度并没有直接的相关关系。
- 连续多层参数一样的卷积操作好像比只用一层效果要好。
相关文章:
卷积神经网络实现咖啡豆分类 - P7
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 目录 环境步骤环境设置包引用全局设备对象 数据准备查看图像的信息制作数据集 模型设…...
C++之默认与自定义构造函数问题(二百一十七)
简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…...
Docker从认识到实践再到底层原理(五)|Docker镜像
前言 那么这里博主先安利一些干货满满的专栏了! 首先是博主的高质量博客的汇总,这个专栏里面的博客,都是博主最最用心写的一部分,干货满满,希望对大家有帮助。 高质量博客汇总 然后就是博主最近最花时间的一个专栏…...
【Flowable】任务监听器(五)
前言 之前有需要使用到Flowable,鉴于网上的资料不是很多也不是很全也是捣鼓了半天,因此争取能在这里简单分享一下经验,帮助有需要的朋友,也非常欢迎大家指出不足的地方。 一、监听器 在Flowable中,我们可以使用监听…...
spring-kafka中ContainerProperties.AckMode详解
近期,我们线上遇到了一个性能问题,几乎快引起线上故障,后来仅仅是修改了一行代码,性能就提升了几十倍。一行代码几十倍,数据听起来很夸张,不过这是真实的数据,线上错误的配置的确有可能导致性能…...
【rpc】Dubbo和Zookeeper结合使用,它们的作用与联系(通俗易懂,一文理解)
目录 Dubbo是什么? 把系统模块变成分布式,有哪些好处,本来能在一台机子上运行,为什么还要远程调用 Zookeeper是什么? 它们进行配合使用时,之间的关系 服务注册 服务发现 动态地址管理 Dubbo是…...
ChatGPT的未来
随着人工智能的快速发展,ChatGPT作为一种自然语言生成模型,在各个领域都展现出了巨大的潜力。它不仅可以用于日常对话、创意助手和知识查询,还可以应用于教育、医疗、商业等各个领域,为人们带来更多便利和创新。 在教育领域&#…...
Pytorch模型转ONNX部署
开始以为会很困难,但是其实非常方便,下边分两步走:1. pytorch模型转onnx;2. 使用onnx进行inference 0. 准备工作 0.1 安装onnx 安装onnx和onnxruntime,onnx貌似是个环境。。倒是没有直接使用,onnxruntim…...
k8s优雅停服
在应用程序的整个生命周期中,正在运行的 pod 会由于多种原因而终止。在某些情况下,Kubernetes 会因用户输入(例如更新或删除 Deployment 时)而终止 pod。在其他情况下,Kubernetes 需要释放给定节点上的资源时会终止 po…...
面试题五:computed的使用
题记 大部分的工作中使用computed的频次很低的,所以今天拿出来一文对于computed进行详细的介绍,因为Vue的灵魂之一就是computed。 模板内的表达式非常便利,但是设计它们的初衷是用于简单运算的。在模板中放入太多的逻辑会让模板过重且难以维护…...
完美的分布式监控系统 Prometheus与优雅的开源可视化平台 Grafana
1、之间的关系 prometheus与grafana之间是相辅相成的关系。简而言之Grafana作为可视化的平台,平台的数据从Prometheus中取到来进行仪表盘的展示。而Prometheus这源源不断的给Grafana提供数据的支持。 Prometheus是一个开源的系统监控和报警系统,能够监…...
黑马JVM总结(九)
(1)StringTable_调优1 我们知道StringTable底层是一个哈希表,哈希表的性能是跟它的大小相关的,如果哈希表这个桶的个数比较多,元素相对分散,哈希碰撞的几率就会减少,查找的速度较快,…...
如何使用 RunwayML 进行创意 AI 创作
标题:如何使用 RunwayML 进行创意 AI 创作 介绍 RunwayML 是一个基于浏览器的人工智能创作工具,可让用户使用各种 AI 功能来生成图像、视频、音乐、文字和其他创意内容。RunwayML 的功能包括: * 图像生成:使用生成式对抗网络 (…...
【css】能被4整除 css :class,判断一个数能否被另外一个数整除,余数
判断一个数能否被另外一个数整除 一个数能被4整除的表达式可以表示为:num%40,其中,num为待判断的数,% 为取模运算符,为等于运算符。这个表达式的意思是,如果num除以4的余数为0,则返回true&…...
ChatGPT与日本首相交流核废水事件-精准Prompt...
了解更多请点击:ChatGPT与日本首相交流核废水事件-精准Prompt...https://mp.weixin.qq.com/s?__bizMzg2NDY3NjY5NA&mid2247490070&idx1&snebdc608acd419bb3e71ca46acee04890&chksmce64e42ff9136d39743d16059e2c9509cc799a7b15e8f4d4f71caa25968554…...
关于 firefox 不能访问 http 的解决
情景: 我在虚拟机 192.168.x.111 上配置了 DNS 服务器,在 kali 上设置 192.168.x.111 为 DNS 服务器后,使用 firefox 地址栏搜索域名 www.xxx.com ,访问在 192.168.x.111 搭建的网站,本来经 192.168.x.111 DNS 服务器解…...
68、Spring Data JPA 的 方法名关键字查询
★ 方法名关键字查询(全自动) (1)继承 CrudRepository 接口 的 DAO 组件可按特定规则来定义查询方法,只要这些查询方法的 方法名 遵守特定的规则,Spring Data 将会自动为这些方法生成 查询语句、提供 方法…...
Brother CNC联网数采集和远程控制
兄弟CNC IP地址设定参考:https://www.sohu.com/a/544461221_121353733没有能力写代码的兄弟可以提前下载好网络调试助手NetAssist,这样就不用写代码来测试连接CNC了。 以上是网络调试助手抓取CNC的产出命令,结果有多个行string需要自行解析&…...
Jenkins 编译 Maven 项目提示错误 version 17
在最近使用集成工具的时候,对项目进行编译提示下面的错误信息: maven-compiler-plugin:3.11.0:compile (default-compile) on project mq-service: Fatal error compiling: error: release version 17 not supported 问题和解决 上面提示的错误信息原…...
数据结构——排序算法——堆排序
堆排序过程如下: 1.用数列构建出一个大顶堆,取出堆顶的数字; 2.调整剩余的数字,构建出新的大顶堆,再次取出堆顶的数字; 3.循环往复,完成整个排序。 构建大顶堆有两种方式: 1.从 0 开…...
【Spring事务底层实现原理】
Transactional注解 Spring使用了TransactionInterceptor拦截器,该拦截器主要负责事务的管理,包括开启、提交、回滚等操作。当在方法上添加Transactional注解时,Spring会在AOP框架中对该方法进行拦截,TransactionInterceptor会在该…...
docker快速安装redis,mysql,minio,nacos等常用软件【持续更新】
redis ①拉取镜像 docker pull redis② 创建容器 docker run -d --name redis --restartalways -p 6379:6379 redis --requirepass "PASSWORD"–requirepass “输入你的redis密码” nacos ①:docker拉取镜像 docker pull nacos/nacos-server:1.2.0②…...
SCRUM产品负责人(CSPO)认证培训课程
课程简介 Scrum是目前运用最为广泛的敏捷开发方法,是一个轻量级的项目管理和产品研发管理框架。产品负责人是Scrum的三个角色之一,产品负责人在Scrum产品开发当中扮演舵手的角色,他决定产品的愿景、路线图以及投资回报,他需要回答…...
python连接mysql数据库的练习
一、导入pandas内置的sqlite3模块,连接的信息:ip地址是本机, 端口号port 是3306, 用户user是root, 密码password是123456, 数据库database是lambda-xiaozhang import pymysql# 打开数据库连接,参数1:主机名或IP;参数…...
扩散模型在图像生成中的应用:从真实样例到逼真图像的奇妙转变
一、扩散模型 扩散模型的起源可以追溯到热力学中的扩散过程。热力学中的扩散过程是指物质从高浓度往低浓度的地方流动,最终达到一种动态的平衡。这个过程就是一个扩散过程。 在深度学习领域中,扩散模型(diffusion models)是深度生…...
Windows 打包 Docker 提示环境错误: no DOCKER_HOST environment variable
这个问题应该还是比较常见的。 [ERROR] Failed to execute goal io.fabric8:docker-maven-plugin:0.40.2:build (default) on project mq-service: Execution default of goal io.fabric8:docker-maven-plugin:0.40.2:build failed: No <dockerHost> given, no DOCKER_H…...
2023.9.8 基于传输层协议 UDP 和 TCP 编写网络通信程序
目录 UDP 基于 UDP 编写网络通信程序 服务器代码 客户端代码 TCP 基于 TCP 编写网络通信程序 服务器代码 客户端代码 IDEA 打开 支持多客户端模式 UDP 特点: 无连接性:发送端和接收端不需要建立连接也可相互通信,且每个 UDP 数据包都…...
单例模式,适用于对象唯一的情景(设计模式与开发实践 P4)
文章目录 单例模式实现代理单例惰性单例 上一章后续的内容是关于 JS 函数闭包的,考虑很多读者已经有了闭包基础或者希望通过实战理解,遂跳过上一章直接开始设计模式篇~ 需要注意的是,代码部分仅供参考,主要关注的内容是…...
C语言实现三子棋游戏(详解)
目录 引言: 1.游戏规则: 2.实现步骤: 2.1实现菜单: 2.2创建棋盘并初始化: 2.3绘制棋盘: 2.4玩家落子: 2.5电脑落子: 2.6判断胜负: 3.源码: 结语&…...
javaee之黑马乐优商城3
异步查询工具axios(儿所以时) vue官方推荐的ajax请求框架 新增品牌页面 如何找到上面这个页面 下面这个页面里面的新增商品弹窗 上面就是请求路径与请求方式 那么请求参数是什么? brand对象,外加商品分类的id数组cids (这里其实不止就是添加…...
西安网站开发工程师/武汉seo首页
13-Figma-组件管理 常见操作 创建组件,选择,点击顶部创建多个组件,框选多个,点击顶部组件使用,对组件进行复制,就创建了组件的实例实例跳到模板,右键-》转到组件模板所有组件如何管理…...
什么网站有做册子版/学生没钱怎么开网店
创业公司怎样才能被大公司收购?对于一家创业公司来说,虽然自己坚持下去也是个不错的选择,不过要是被像苹果或Google这样最具价值的公司收购,那将是一件无比激动人心的事,那怎么做才能受到它们的青睐,吸引它…...
青岛模版网站建设/简述什么是seo及seo的作用
副标题——别把技术问题转化为人际问题(作者:孙继滨)【项目经理之修炼】 全文索引 刚刚进入项目时,由于权威还没有树立,人际关系尚浅,经常会有组员不把你当回事儿。于是: 你要求组员写周报&…...
在哪个网站做兼职淘宝客服/互联网推广是做什么的
pandas 新增sheet,不覆盖原来已经保存的sheet(亲测管用) 口袋里的小小哥 2020-06-02 14:49:02 2759 收藏 9 分类专栏: python 文章标签: pandas 新增sheet不覆盖原数据 数据分析 版权 #以前的sheet数据很重要&#…...
长沙百度做网站多少钱/推广软件赚钱的平台
手机上有好用的时钟APP,时钟软件哪个好用? 每个人的手机屏幕都有着自己独特的设计,而时钟软件能够为用户的手机再增添一些风采和魅力,更有专为来年人设计的大字体桌面时钟,更加的清晰直观,各种特色的桌面时钟app中都有…...
网站开发使用的软件/google搜索下载
随着SPA应用程序的普及,前端开发中对路由的操作越来越离不开了,目前基本上前端框架中的路由都是基于hash和history模式,在大量使用之余我们也理应该去窥探一下它他们的真面hashhash模式的原理其实非常简单,它提供了一个onhashchan…...