当前位置: 首页 > news >正文

【计算机视觉 | Pytorch】timm 包的具体介绍和图像分类案例(含源代码)

一、具体介绍

timm 是一个 PyTorch 原生实现的计算机视觉模型库。它提供了预训练模型和各种网络组件,可以用于各种计算机视觉任务,例如图像分类、物体检测、语义分割等等。

timm 的特点如下:

  1. PyTorch 原生实现:timm 的实现方式与 PyTorch 高度契合,开发者可以方便地使用 PyTorchAPI 进行模型训练和部署。
  2. 轻量级的设计:timm 的设计以轻量化为基础,根据不同的计算机视觉任务,提供了多种轻量级的网络结构。
  3. 大量的预训练模型:timm 提供了大量的预训练模型,可以直接用于各种计算机视觉任务。
  4. 多种模型组件:timm 提供了各种模型组件,如注意力模块、正则化模块、激活函数等等,这些模块都可以方便地插入到自己的模型中。
  5. 高效的代码实现:timm 的代码实现高效并且易于使用。

需要注意的是,timm 是一个社区驱动的项目,它由计算机视觉领域的专家共同开发和维护。在使用时需要遵循相关的使用协议。

二、图像分类案例

下面以使用 timm 实现图像分类任务为例,进行简单的介绍。

2.1 安装 timm 包

!pip install timm

2.2 导入相关模块,读取数据集

import torch
import torch.nn as nn
import timm
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10# 数据增强
train_transforms = transforms.Compose([transforms.RandomCrop(size=32, padding=4),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(degrees=15),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])test_transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])# 数据集
train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transforms)
test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transforms)# DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

导入相关模块,其中 timmtorchvision.datasets.CIFAR10 需要分别安装 timmtorchvision 包。

定义数据增强的方式,其中训练集和测试集分别使用不同的增强方式,并且对图像进行了归一化处理。transforms.Compose() 可以将各种操作打包成一个 transform 操作流,transforms.ToTensor() 将图像转化为 tensor 格式,transforms.Normalize() 将图像进行标准化处理。

使用自带的 CIFAR10 数据集,设置 train=True 定义训练集,设置 train=False 定义测试集。数据集会自动下载到指定的 root 路径下,并进行数据增强操作。

使用 torch.utils.data.DataLoader 定义数据加载器,将数据集包装成一个高效的可迭代对象,其中 batch_size 定义批次大小,shuffle 定义是否对数据进行随机洗牌,num_workers 定义使用多少个 worker 来加载数据。

在这里插入图片描述

2.3 定义模型

# 加载预训练模型
model = timm.create_model('resnet18', pretrained=True)# 修改分类器
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes))

这里使用 timm.create_model() 函数来创建一个预训练模型,其中参数 resnet18 定义了使用的模型架构,参数 pretrained = True 表示要使用预训练权重。

这里修改了模型的分类器,首先使用 model.fc.in_features 获取模型 fc 层的输入特征数,然后使用 nn.Linear() 重新定义了一个 nn.Linear 层,输入为上一层的输出特征数,输出为类别数(即 len(train_dataset.classes))。这里直接使用了数据集类别数来定义输出层,以适配不同分类任务的需求。

在这里插入图片描述
在这里,我们使用了 timm 中的 ResNet18 模型,并将其修改为我们需要的分类器,同时在创建模型时,设置参数 pretrained=True 来加载预训练权重。

2.4 定义损失函数和优化器

# 损失函数
criterion = nn.CrossEntropyLoss()# 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

在深度学习中,损失函数是评估模型预测结果与真实标签之间差异的一种指标,常用于模型训练过程中。nn.CrossEntropyLoss() 是一个常用的损失函数,适用于多分类问题。

优化器用于更新模型参数以使损失函数最小化。在这里,我们使用了随机梯度下降法(SGD)优化器,以控制模型权重的变化。通过 model.parameters() 指定需要优化的参数,lr 定义了学习率,表示每次迭代时参数必须更新的量的大小,momentum 则是添加上次迭代更新值的一部分到这一次的更新值中,以减小参数更新的方差,稳定训练过程。

2.5 训练模型

num_epochs = 10for epoch in range(num_epochs):# 训练model.train()for images, labels in train_loader:# 前向传播outputs = model(images)# 计算损失loss = criterion(outputs, labels)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 测试model.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Epoch {} Accuracy: {:.2f}%'.format(epoch+1, 100*correct/total))

这段代码是模型训练和测试的循环。num_epochs 定义了循环的次数,每次循环表示一个训练周期。

在训练阶段,首先将模型切换到训练模式,然后使用 train_loader 迭代地读取训练集数据,进行前向传播、计算损失、反向传播和优化器更新等操作。

在测试阶段,模型切换到评估模式,然后使用 test_loader 读取测试集数据,进行前向传播和计算模型预测结果,使用预测结果和真实标签进行准确率计算,并输出每个训练周期的准确率。

其中,torch.max() 函数用于返回每行中最大值及其索引,total 记录了总的测试样本数,correct 记录了正确分类的样本数,最后计算准确率并输出。

输出结果为:

在这里插入图片描述

相关文章:

【计算机视觉 | Pytorch】timm 包的具体介绍和图像分类案例(含源代码)

一、具体介绍 timm 是一个 PyTorch 原生实现的计算机视觉模型库。它提供了预训练模型和各种网络组件,可以用于各种计算机视觉任务,例如图像分类、物体检测、语义分割等等。 timm 的特点如下: PyTorch 原生实现:timm 的实现方式…...

轻博客Plume的搭建

什么是 Plume ? Plume 是一个基于 ActivityPub 的联合博客引擎。它是用 Rust 编写的,带有 Rocket 框架,以及 Diesel 与数据库交互。前端使用 Ructe模板、WASM 和SCSS。 反向代理 假设我们实际访问地址为: https://plume.laosu.ml…...

机器人关节电机PWM

脉冲宽度调制(Pulse width modulation,PWM)技术。一种模拟控制方式 机器人关节电机的控制通常使用PWM(脉冲宽度调制)技术。PWM是一种用于控制电子设备的技术,通过控制高电平和低电平之间的时间比例,实现对电子设备的控制。在机器人关节电机中,PWM信号可以控制电机的…...

MPU6050详解(含源码)

前言:MPU6050是一款强大的六轴传感器,需要理解MPU6050首先得有IIC的基础,MPU6050 内部整合了 3 轴陀螺仪和 3 轴加速度传感器,并且含有一个第二 IIC 接口,可用于连接外部磁力传感器,内部有硬件算法支持. 1…...

Vue入门学习笔记:TodoList(三):实例中的数据、事件和方法

目录: Vue入门学习笔记:TodoList(一):HelloWorld Vue入门学习笔记:TodoList(二):挂载点、模板、实例 Vue入门学习笔记:TodoList(三)&a…...

怎么找到引发回流的JavaScript代码?

要找到引发回流的JavaScript代码,可以使用浏览器的开发者工具中的性能分析器。不同的浏览器有不同的名称和位置,例如Google Chrome的开发者工具中的性能分析器被称为Performance,Firefox的开发者工具中的性能分析器被称为Profiler。 以下是在…...

未来广告策划,转型还是淘汰?

在广告行业呆了十来年了,最近我越来越感觉到广告行业真的是一个需要与时俱进,并且应用场景非常广泛的一个专业。 而且由于这是一个需要创意能力的行业,所以对比于重复性容易被机器以及人工智能所代替的岗位行业来说,广告的可替代…...

【vscode远程开发】使用SSH远程连接服务器 「内网穿透」

文章目录 前言视频教程1、安装OpenSSH2、vscode配置ssh3. 局域网测试连接远程服务器4. 公网远程连接4.1 ubuntu安装cpolar内网穿透4.2 创建隧道映射4.3 测试公网远程连接 5. 配置固定TCP端口地址5.1 保留一个固定TCP端口地址5.2 配置固定TCP端口地址5.3 测试固定公网地址远程 转…...

七天从零实现Web框架Gee - 扩展

到这里前七天的任务已经完成,但我们可以对Gee框架进行一些扩展 补充HTTP请求方法 原作者只实现了 GET, POST 路由添加,其他的 PUT, DELETE 等标准 HTTP 方法未实现,实现方法也很简单,只需在gee.go中增加如下代码 // PUT define…...

什么是土壤水分传感器

土壤水分传感器又称土壤湿度传感器由不锈钢探针和防水探头构成,可长期埋设于土壤和堤坝内使用,对表层和深层土壤进行墒情的定点监测和在线测量。与数据采集器配合使用,可作为水分定点监测或移动测量的工具(即农田墒情检测仪&#…...

月薪17k需要什么水平?98年测试员的面试全过程…

我的情况 大概介绍一下个人情况,男,本科,三年多测试工作经验,懂python,会写脚本,会selenium,会性能,然而到今天都没有收到一份offer!从年后就开始准备简历,年…...

知了汇智:坚持发展产教融合,做好高校、人才与企业之间的桥梁

6月将正式迎来高校毕业季,大学生就业是聚焦全社会关注的头等大事。5月9日,成都知了汇智科技有限公司(以下简称“知了汇智”)组织开展“深化产教融合、聚焦人才培养”的主题座谈会议,联动高校与合作企业参加&#xff0c…...

MyBatis缓存-一级缓存--二级缓存的非常详细的介绍

目录 MyBatis-缓存-提高检索效率的利器 缓存-官方文档 一级缓存 基本说明 一级缓存原理图 代码演示 修改MonsterMapperTest.java, 增加测试方法 结果 debug 一级缓存执行流程 一级缓存失效分析 关闭sqlSession会话后 , 一级缓存失效 如果执行sqlSession.clearCache(…...

macOS Ventura 13.4 RC2(22F63)发布

系统介绍 根据黑果魏叔官网提供:5 月 12 日消息,苹果今天面向开发人员,发布了 macOS Ventura 13.4 的第 2 个候选 RC 版本(内部版本号 22F63),距离上个候选版本相隔数天时间。 macOS Ventura 带来了台前调…...

【为什么可以相信一个HTTPS网站】

解决信用,仅仅有加密和解密是不够的。加密解密解决的只是传输链路的安全问题,相当于两个人说话不被窃听。可以类比成你现在生活 的世界——货币的信用,是由政府在背后支撑的;购房贷款的信用,是由银行在背后支撑的&…...

4.进阶篇

目录 一、按照测试对象划分 1.界面测试(UI测试) 界面测试的常见错误: 2.可靠性测试 3.容错性测试 4.文档测试 5.兼容性测试 6.易用性 7.安装卸载测试 8.安全性测试 9.性能测试 10.内存泄漏 二、按照是否查看代码 1.黑盒测试 2.…...

conda init

在输入conda activate 的时候出现报错: 解决: "需要使用 conda init 进行初始化" 的错误通常是由于你的系统环境缺少 conda 的初始化脚本所致。当你尝试在终端中执行 conda activate 命令时,会出现此错误提示。 要解决这个问题,可以通过以下步骤进行操作: 打…...

Elasticsearch(二)

Clasticsearch(二) DSL查询语法 文档 文档:https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html 常见查询类型包括: 查询所有:查询出所有数据,一般测试用。如&#xff1a…...

工业视觉检测的8个技术优势

工业4.0时代,自动化生产线成为了这个时代的主旋律,而工业视觉检测技术也成为其中亮眼的表现,其机器视觉技术为设备提供了智慧的双眼,让自动化的脚步得以加速! 在实际的生产应用中,视觉技术方案往往先被着手…...

16 KVM虚拟机配置-其他常见配置项

文章目录 16 KVM虚拟机配置-其他常见配置项16.1 概述16.2 元素介绍16.3 配置示例 16 KVM虚拟机配置-其他常见配置项 16.1 概述 除系统资源和虚拟设备外,XML配置文件还需要配置一些其他元素,本节介绍这些元素的配置方法。 16.2 元素介绍 iothreads&…...

(转载)从0开始学matlab(第1天)—变量和数组

MATLAB 程序的基本数据单元是数组。一个数组是以行和列组织起来的数据集合,并且拥有一个数组名。数组中的单个数据是可以被访问的,访问的方法是数组名后带一个括号,括号内是这个数据所对应行标和列标。标量在 MATLAB 中也被当作数组来处理——…...

Linux命令·wget

Linux系统中的wget是一个下载文件的工具,它用在命令行下。对于Linux用户是必不可少的工具,我们经常要下载一些软件或从远程服务器恢复备份到本地服务器。wget支持HTTP,HTTPS和FTP协议,可以使用HTTP代理。所谓的自动下载是指&#…...

API网关简介|TaobaoAPI接入

API网关是什么 在日常工作中,不同的场合下,我们可能听说过很多次网关这个名称,这里说的网关特指API网关(API Gataway)。字面意思是指将所有API的调用统一接入API网关层,由网关层负责接入和输出。 那么在什…...

OJ练习第103题——最大矩形

最大矩形 力扣链接:85. 最大矩形 题目描述 给定一个仅包含 0 和 1 、大小为 rows x cols 的二维二进制矩阵,找出只包含 1 的最大矩形,并返回其面积。 示例 输入:matrix [[“1”,“0”,“1”,“0”,“0”],[“1”,“0”,“1”…...

JavaScript实现输入年份判断是否为闰年的代码

以下为实现输入年份判断是否为闰年的程序代码和运行截图 目录 前言 一、输入年份判断是否为闰年 1.1 运行流程及思想 1.2 代码段 1.3 JavaScript语句代码 1.4 运行截图 前言 1.若有选择,您可以在目录里进行快速查找; 2.本博文代码可以根据题目要…...

LiangGaRy-学习笔记-Day12

1、作业回顾 1.1、判断磁盘利用率 要求: 判断磁盘的使用率,如果超过了90%就警告 [rootNode1 sh]# vim disk_check.sh #!/bin/bash #Author By LiangGaRy #2023年5月9日 #Usage:检测硬盘的使用率 ########################################### #定义一…...

LayUI中弹出层select动态回显设置及子页面刷新父页面Table数据方法

...

浅谈Hutool工具类

一、Hutool简介 Hutool是一个Java工具类库,它封装了很多常用的Java工具类,如加密解密、文件操作、日期时间处理、Http客户端等。它的目标是让Java开发变得更加简单、高效。 二、Hutool的特点 高效:提供了很多高效的工具类和方法。 简单&…...

Mac终端代理

1.打开代理查看代理端口号 打开设置,点击网络,点击详细信息,点击代理查看代理端口号。 2.修改环境变量 1)终端输入下面命令 vim .zshrc 2)在.zshrc文件里添加下面两段内容(注意:7980为端口号…...

Git Clone 报错 `SSL certificate problem: unable to get local issuer certificate`

如果您在尝试克隆Git存储库时得到 “SSL certificate problem: unable to get local issuer certificate” 的错误,这意味着Git无法验证远程存储库的SSL证书。如果SSL证书是自签名的,或者SSL证书链有问题,就会发生这种情况。 $ git clone https://githu…...