【计算机视觉 | Pytorch】timm 包的具体介绍和图像分类案例(含源代码)
一、具体介绍
timm 是一个 PyTorch 原生实现的计算机视觉模型库。它提供了预训练模型和各种网络组件,可以用于各种计算机视觉任务,例如图像分类、物体检测、语义分割等等。
timm 的特点如下:
- PyTorch 原生实现:timm 的实现方式与 PyTorch 高度契合,开发者可以方便地使用 PyTorch 的 API 进行模型训练和部署。
- 轻量级的设计:timm 的设计以轻量化为基础,根据不同的计算机视觉任务,提供了多种轻量级的网络结构。
- 大量的预训练模型:timm 提供了大量的预训练模型,可以直接用于各种计算机视觉任务。
- 多种模型组件:timm 提供了各种模型组件,如注意力模块、正则化模块、激活函数等等,这些模块都可以方便地插入到自己的模型中。
- 高效的代码实现:timm 的代码实现高效并且易于使用。
需要注意的是,timm 是一个社区驱动的项目,它由计算机视觉领域的专家共同开发和维护。在使用时需要遵循相关的使用协议。
二、图像分类案例
下面以使用 timm 实现图像分类任务为例,进行简单的介绍。
2.1 安装 timm 包
!pip install timm
2.2 导入相关模块,读取数据集
import torch
import torch.nn as nn
import timm
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10# 数据增强
train_transforms = transforms.Compose([transforms.RandomCrop(size=32, padding=4),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomRotation(degrees=15),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])test_transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])# 数据集
train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transforms)
test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transforms)# DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
导入相关模块,其中 timm 和 torchvision.datasets.CIFAR10 需要分别安装 timm 和 torchvision 包。
定义数据增强的方式,其中训练集和测试集分别使用不同的增强方式,并且对图像进行了归一化处理。transforms.Compose() 可以将各种操作打包成一个 transform 操作流,transforms.ToTensor() 将图像转化为 tensor 格式,transforms.Normalize() 将图像进行标准化处理。
使用自带的 CIFAR10 数据集,设置 train=True 定义训练集,设置 train=False 定义测试集。数据集会自动下载到指定的 root 路径下,并进行数据增强操作。
使用 torch.utils.data.DataLoader 定义数据加载器,将数据集包装成一个高效的可迭代对象,其中 batch_size 定义批次大小,shuffle 定义是否对数据进行随机洗牌,num_workers 定义使用多少个 worker 来加载数据。

2.3 定义模型
# 加载预训练模型
model = timm.create_model('resnet18', pretrained=True)# 修改分类器
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes))
这里使用 timm.create_model() 函数来创建一个预训练模型,其中参数 resnet18 定义了使用的模型架构,参数 pretrained = True 表示要使用预训练权重。
这里修改了模型的分类器,首先使用 model.fc.in_features 获取模型 fc 层的输入特征数,然后使用 nn.Linear() 重新定义了一个 nn.Linear 层,输入为上一层的输出特征数,输出为类别数(即 len(train_dataset.classes))。这里直接使用了数据集类别数来定义输出层,以适配不同分类任务的需求。

在这里,我们使用了 timm 中的 ResNet18 模型,并将其修改为我们需要的分类器,同时在创建模型时,设置参数 pretrained=True 来加载预训练权重。
2.4 定义损失函数和优化器
# 损失函数
criterion = nn.CrossEntropyLoss()# 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
在深度学习中,损失函数是评估模型预测结果与真实标签之间差异的一种指标,常用于模型训练过程中。nn.CrossEntropyLoss() 是一个常用的损失函数,适用于多分类问题。
优化器用于更新模型参数以使损失函数最小化。在这里,我们使用了随机梯度下降法(SGD)优化器,以控制模型权重的变化。通过 model.parameters() 指定需要优化的参数,lr 定义了学习率,表示每次迭代时参数必须更新的量的大小,momentum 则是添加上次迭代更新值的一部分到这一次的更新值中,以减小参数更新的方差,稳定训练过程。
2.5 训练模型
num_epochs = 10for epoch in range(num_epochs):# 训练model.train()for images, labels in train_loader:# 前向传播outputs = model(images)# 计算损失loss = criterion(outputs, labels)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 测试model.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Epoch {} Accuracy: {:.2f}%'.format(epoch+1, 100*correct/total))
这段代码是模型训练和测试的循环。num_epochs 定义了循环的次数,每次循环表示一个训练周期。
在训练阶段,首先将模型切换到训练模式,然后使用 train_loader 迭代地读取训练集数据,进行前向传播、计算损失、反向传播和优化器更新等操作。
在测试阶段,模型切换到评估模式,然后使用 test_loader 读取测试集数据,进行前向传播和计算模型预测结果,使用预测结果和真实标签进行准确率计算,并输出每个训练周期的准确率。
其中,torch.max() 函数用于返回每行中最大值及其索引,total 记录了总的测试样本数,correct 记录了正确分类的样本数,最后计算准确率并输出。
输出结果为:

相关文章:
【计算机视觉 | Pytorch】timm 包的具体介绍和图像分类案例(含源代码)
一、具体介绍 timm 是一个 PyTorch 原生实现的计算机视觉模型库。它提供了预训练模型和各种网络组件,可以用于各种计算机视觉任务,例如图像分类、物体检测、语义分割等等。 timm 的特点如下: PyTorch 原生实现:timm 的实现方式…...
轻博客Plume的搭建
什么是 Plume ? Plume 是一个基于 ActivityPub 的联合博客引擎。它是用 Rust 编写的,带有 Rocket 框架,以及 Diesel 与数据库交互。前端使用 Ructe模板、WASM 和SCSS。 反向代理 假设我们实际访问地址为: https://plume.laosu.ml…...
机器人关节电机PWM
脉冲宽度调制(Pulse width modulation,PWM)技术。一种模拟控制方式 机器人关节电机的控制通常使用PWM(脉冲宽度调制)技术。PWM是一种用于控制电子设备的技术,通过控制高电平和低电平之间的时间比例,实现对电子设备的控制。在机器人关节电机中,PWM信号可以控制电机的…...
MPU6050详解(含源码)
前言:MPU6050是一款强大的六轴传感器,需要理解MPU6050首先得有IIC的基础,MPU6050 内部整合了 3 轴陀螺仪和 3 轴加速度传感器,并且含有一个第二 IIC 接口,可用于连接外部磁力传感器,内部有硬件算法支持. 1…...
Vue入门学习笔记:TodoList(三):实例中的数据、事件和方法
目录: Vue入门学习笔记:TodoList(一):HelloWorld Vue入门学习笔记:TodoList(二):挂载点、模板、实例 Vue入门学习笔记:TodoList(三)&a…...
怎么找到引发回流的JavaScript代码?
要找到引发回流的JavaScript代码,可以使用浏览器的开发者工具中的性能分析器。不同的浏览器有不同的名称和位置,例如Google Chrome的开发者工具中的性能分析器被称为Performance,Firefox的开发者工具中的性能分析器被称为Profiler。 以下是在…...
未来广告策划,转型还是淘汰?
在广告行业呆了十来年了,最近我越来越感觉到广告行业真的是一个需要与时俱进,并且应用场景非常广泛的一个专业。 而且由于这是一个需要创意能力的行业,所以对比于重复性容易被机器以及人工智能所代替的岗位行业来说,广告的可替代…...
【vscode远程开发】使用SSH远程连接服务器 「内网穿透」
文章目录 前言视频教程1、安装OpenSSH2、vscode配置ssh3. 局域网测试连接远程服务器4. 公网远程连接4.1 ubuntu安装cpolar内网穿透4.2 创建隧道映射4.3 测试公网远程连接 5. 配置固定TCP端口地址5.1 保留一个固定TCP端口地址5.2 配置固定TCP端口地址5.3 测试固定公网地址远程 转…...
七天从零实现Web框架Gee - 扩展
到这里前七天的任务已经完成,但我们可以对Gee框架进行一些扩展 补充HTTP请求方法 原作者只实现了 GET, POST 路由添加,其他的 PUT, DELETE 等标准 HTTP 方法未实现,实现方法也很简单,只需在gee.go中增加如下代码 // PUT define…...
什么是土壤水分传感器
土壤水分传感器又称土壤湿度传感器由不锈钢探针和防水探头构成,可长期埋设于土壤和堤坝内使用,对表层和深层土壤进行墒情的定点监测和在线测量。与数据采集器配合使用,可作为水分定点监测或移动测量的工具(即农田墒情检测仪&#…...
月薪17k需要什么水平?98年测试员的面试全过程…
我的情况 大概介绍一下个人情况,男,本科,三年多测试工作经验,懂python,会写脚本,会selenium,会性能,然而到今天都没有收到一份offer!从年后就开始准备简历,年…...
知了汇智:坚持发展产教融合,做好高校、人才与企业之间的桥梁
6月将正式迎来高校毕业季,大学生就业是聚焦全社会关注的头等大事。5月9日,成都知了汇智科技有限公司(以下简称“知了汇智”)组织开展“深化产教融合、聚焦人才培养”的主题座谈会议,联动高校与合作企业参加,…...
MyBatis缓存-一级缓存--二级缓存的非常详细的介绍
目录 MyBatis-缓存-提高检索效率的利器 缓存-官方文档 一级缓存 基本说明 一级缓存原理图 代码演示 修改MonsterMapperTest.java, 增加测试方法 结果 debug 一级缓存执行流程 一级缓存失效分析 关闭sqlSession会话后 , 一级缓存失效 如果执行sqlSession.clearCache(…...
macOS Ventura 13.4 RC2(22F63)发布
系统介绍 根据黑果魏叔官网提供:5 月 12 日消息,苹果今天面向开发人员,发布了 macOS Ventura 13.4 的第 2 个候选 RC 版本(内部版本号 22F63),距离上个候选版本相隔数天时间。 macOS Ventura 带来了台前调…...
【为什么可以相信一个HTTPS网站】
解决信用,仅仅有加密和解密是不够的。加密解密解决的只是传输链路的安全问题,相当于两个人说话不被窃听。可以类比成你现在生活 的世界——货币的信用,是由政府在背后支撑的;购房贷款的信用,是由银行在背后支撑的&…...
4.进阶篇
目录 一、按照测试对象划分 1.界面测试(UI测试) 界面测试的常见错误: 2.可靠性测试 3.容错性测试 4.文档测试 5.兼容性测试 6.易用性 7.安装卸载测试 8.安全性测试 9.性能测试 10.内存泄漏 二、按照是否查看代码 1.黑盒测试 2.…...
conda init
在输入conda activate 的时候出现报错: 解决: "需要使用 conda init 进行初始化" 的错误通常是由于你的系统环境缺少 conda 的初始化脚本所致。当你尝试在终端中执行 conda activate 命令时,会出现此错误提示。 要解决这个问题,可以通过以下步骤进行操作: 打…...
Elasticsearch(二)
Clasticsearch(二) DSL查询语法 文档 文档:https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html 常见查询类型包括: 查询所有:查询出所有数据,一般测试用。如:…...
工业视觉检测的8个技术优势
工业4.0时代,自动化生产线成为了这个时代的主旋律,而工业视觉检测技术也成为其中亮眼的表现,其机器视觉技术为设备提供了智慧的双眼,让自动化的脚步得以加速! 在实际的生产应用中,视觉技术方案往往先被着手…...
16 KVM虚拟机配置-其他常见配置项
文章目录 16 KVM虚拟机配置-其他常见配置项16.1 概述16.2 元素介绍16.3 配置示例 16 KVM虚拟机配置-其他常见配置项 16.1 概述 除系统资源和虚拟设备外,XML配置文件还需要配置一些其他元素,本节介绍这些元素的配置方法。 16.2 元素介绍 iothreads&…...
反向工程与模型迁移:打造未来商品详情API的可持续创新体系
在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...
JVM垃圾回收机制全解析
Java虚拟机(JVM)中的垃圾收集器(Garbage Collector,简称GC)是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象,从而释放内存空间,避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...
c++ 面试题(1)-----深度优先搜索(DFS)实现
操作系统:ubuntu22.04 IDE:Visual Studio Code 编程语言:C11 题目描述 地上有一个 m 行 n 列的方格,从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子,但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...
将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?
Otsu 是一种自动阈值化方法,用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理,能够自动确定一个阈值,将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...
鱼香ros docker配置镜像报错:https://registry-1.docker.io/v2/
使用鱼香ros一件安装docker时的https://registry-1.docker.io/v2/问题 一键安装指令 wget http://fishros.com/install -O fishros && . fishros出现问题:docker pull 失败 网络不同,需要使用镜像源 按照如下步骤操作 sudo vi /etc/docker/dae…...
什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南
文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...
NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合
在汽车智能化的汹涌浪潮中,车辆不再仅仅是传统的交通工具,而是逐步演变为高度智能的移动终端。这一转变的核心支撑,来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒(T-Box)方案:NXP S32K146 与…...
Java毕业设计:WML信息查询与后端信息发布系统开发
JAVAWML信息查询与后端信息发布系统实现 一、系统概述 本系统基于Java和WML(无线标记语言)技术开发,实现了移动设备上的信息查询与后端信息发布功能。系统采用B/S架构,服务器端使用Java Servlet处理请求,数据库采用MySQL存储信息࿰…...
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的“no matching...“系列算法协商失败问题
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的"no matching..."系列算法协商失败问题 摘要: 近期,在使用较新版本的OpenSSH客户端连接老旧SSH服务器时,会遇到 "no matching key exchange method found", "n…...
