第58步 深度学习图像识别:Transformer可视化(Pytorch)
一、写在前面
(1)pytorch_grad_cam库
这一期补上基于基于Transformer框架可视化的教程和代码,使用的是pytorch_grad_cam库,以Bottleneck Transformer模型为例。
(2)算法分类
pytorch_grad_cam库中包含的主要方法有以下几种:
GradCAM: 这是最基本的方法。GradCAM(Gradient-weighted Class Activation Mapping)通过取网络最后一个卷积层的特征图,然后对这些特征图进行加权求和,得到类别激活图。加权的系数是网络最后一个卷积层特征图对应类别的梯度的全局平均池化值。
GradCAMPlusPlus: 这是在GradCAM的基础上的改进。GradCAM++不仅计算了类别相对于特征图的梯度,还计算了二阶和三阶导数。这使得GradCAM++在某些情况下可以获得更细粒度的解释。
ScoreCAM: ScoreCAM采用了不同的策略。它对于每个特征图都生成一个类似的激活图,并将所有这些激活图加权求和。权重是每个特征图对应的类别分数。
AblationCAM: AblationCAM是基于Ablation-based的方法。它首先对每个特征图进行遮挡(或移除),然后看类别得分如何改变。这些改变被用来生成类别激活图。
XGradCAM: 这是GradCAM的另一个扩展。XGradCAM考虑了激活和梯度之间的空间关系,以生成更详细的类别激活图。
EigenCAM: 它基于主成分分析 (PCA) 的方法,利用协方差矩阵的特征向量和特征值来表示激活图。
FullGrad: FullGrad是一个对输入,权重和偏差的特征重要性进行全局分解的方法。
以上方法都在解释深度学习模型的决策,可以帮助理解模型关注的区域和特征。在选择使用哪种方法时,可以根据需求和实验效果进行选择。
二、Transformer可视化实战
继续使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。
(a)Bottleneck Transformer建模
######################################导入包###################################
# 导入必要的包
import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as npwarnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")################################导入数据集#####################################
import torch
from torchvision import datasets, transforms
import os# 数据集路径
data_dir = "./MTB"# 图像的大小
img_height = 256
img_width = 256# 数据预处理
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(img_height),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((img_height, img_width)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
full_dataset = datasets.ImageFolder(data_dir)# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.7 * full_size) # 假设训练集占80%
val_size = full_size - train_size # 验证集的大小# 随机分割数据集
torch.manual_seed(0) # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])# 将数据增强应用到训练集
train_dataset.dataset.transform = data_transforms['train']# 创建数据加载器
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes###############################定义模型################################
# 导入必要的库
import torch.nn as nn
import timm# 定义Bottleneck Transformer模型
model = timm.create_model('botnet26t_256', pretrained=True) # 你可以选择适合你需求的BotNet版本
num_ftrs = model.feature_info[-1]['num_chs']# 根据分类任务修改最后一层
model.head.fc = nn.Linear(num_ftrs, len(class_names))# 将模型移至指定设备
model = model.to(device)# 打印模型摘要
print(model)#############################编译模型#########################################
# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.Adam(model.parameters())# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 开始训练模型
num_epochs = 2# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train() # 设置模型为训练模式else:model.eval() # 设置模型为评估模式running_loss = 0.0running_corrects = 0# 遍历数据for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 零参数梯度optimizer.zero_grad()# 前向with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 只在训练模式下进行反向和优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()# 记录每个epoch的loss和accuracyif phase == 'train':train_loss_history.append(epoch_loss)train_acc_history.append(epoch_acc)else:val_loss_history.append(epoch_loss)val_acc_history.append(epoch_acc)print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))print()# 保存模型
torch.save(model.state_dict(), 'botnet_dit_model.pth')
(b)使用GradCAM可视化
在跑之前,得先安装git;然后用git安装pytorch_grad_cam:
安装git容易,无脑输入:
conda install git
安装pytorch_grad_cam也不难:
git clone https://github.com/jacobgil/pytorch-grad-cam.git
cd pytorch-grad-cam
pip install .
然后码代码:
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import timm# 代码1中的函数
def myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens,size))if titles == False:titles="0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor,heatmap=False,shape=(256,256)):np_arr=tensor.detach().numpy()#[0]#对数据进行归一化if np_arr.max()>1 or np_arr.min()<0:np_arr=np_arr-np_arr.min()np_arr=np_arr/np_arr.max()#np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0]==1:# 如果是灰度图像,复制三个通道以创建一个RGB图像np_arr=np.concatenate([np_arr,np_arr,np_arr],axis=0)np_arr=np_arr.transpose((1,2,0))return np_arr# 加载模型
model = timm.create_model('botnet26t_256', pretrained=False)# 更改全连接层以匹配你的类别数
num_ftrs = model.head.fc.in_features
model.head.fc = nn.Linear(num_ftrs, 2) # 假设你的类别数为2model.load_state_dict(torch.load('botnet_dit_model.pth', map_location=device))# 模型转移到相应设备
model = model.to(device)# 你的图像路径
image_path = './MTB/Tuberculosis/Tuberculosis-203.png'# 加载图像
image = Image.open(image_path).convert("RGB")# 使用代码1中定义的图像转换
input_image = data_transforms['val'](image).unsqueeze(0).to(device)# 使用GradCAM
target_layer = model.stages[2][0].conv3_1x1.bn.drop
with GradCAM(model=model, target_layers=[target_layer], use_cuda=torch.cuda.is_available()) as cam:target = [ClassifierOutputTarget(1)] # 修改为你的目标类别grayscale_cam = cam(input_tensor=input_image, targets=target)#将热力图结果与原图进行融合rgb_img=tensor2img(input_image.cpu().squeeze())visualization = show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
myimshows([rgb_img, grayscale_cam[0], visualization],["image","cam","image + cam"])
结果输出如下:
红色区域就是模型认为的“可疑区域”,也就是说模型根据这些区域判断它是Tuberculosis的主要依据。
几个注意事项:
(a)问:代码:‘target = [ClassifierOutputTarget(0)] # 修改为你的目标类别’,这个怎么解释?此外,0和1分别代表什么呢?
答:第一小问:一般来说,ClassifierOutputTarget(0)中的0代表的是你希望将注意力图(CAM)生成针对的类别标签。例如,如果你的两个类别是猫和狗,且在训练数据集中猫的标签是0,狗的标签是1,那么ClassifierOutputTarget(0)将生成猫的注意力图,而ClassifierOutputTarget(1)将生成狗的注意力图。
第二小问:在 PyTorch 中,使用 ImageFolder 函数或类似的数据加载器加载数据时,类别名称列表(class_names)的顺序将决定了类别标签的分配。这意味着类别名称列表的索引将作为类别的标签。在我们的例子中,class_names = ['Normal', 'Tuberculosis'],"Normal" 的索引是 0,所以它的标签是 0;"Tuberculosis" 的索引是 1,所以它的标签是 1。所以ClassifierOutputTarget(0) 将生成"Normal"类别的注意力图,ClassifierOutputTarget(1) 将生成"Tuberculosis"类别的注意力图。
(b)问:代码:‘target_layer = model.stages[2][0].conv3_1x1.conv’,如何选择输出的层?怎么知道模型中有哪些层?
答:第一小问:一般来说,卷积层或者重复结构的最后一层(如 ResNet 中的每个残差块的最后一层)是可行的目标层,因为这些层能保留空间信息,而全连接层则不行,因为它们不再保留空间信息。
第二小问:通过下面代码打印出模型中所有层次的名称:
#打印出模型中所有层次的名称
for name, module in model.named_modules():
print(name)
输出如下:
或者打印出模型的顶层子模块:
#打印模型的顶层子模块
for name, module in model.named_children():print(name)
输出就四个:
stem
stages
final_conv
head
接下来,展示几个层的写法,大家自行体会:
stem.conv2.conv :target_layer = model.stem.conv2.conv
stages.3.1.conv1_1x1:target_layer = model.stages[3][1].conv1_1x1
final_conv:target_layer = model.final_conv
应该找到规律了吧,不详细解释了。每一层输出是不一样的,例如上面三层输出依次如下:
(c)问:如何改用其他7种方法来替代GradCAM?
答:很简单,来到这个代码段:
with GradCAM(model=model, target_layers=[target_layer], use_cuda=torch.cuda.is_available()) as cam:target = [ClassifierOutputTarget(0)] # 修改为你的目标类别grayscale_cam = cam(input_tensor=input_image, targets=target)#将热力图结果与原图进行融合rgb_img=tensor2img(input_image.cpu().squeeze())visualization = show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
myimshows([rgb_img, grayscale_cam[0], visualization],["image","cam","image + cam"])
只需要把GradCAM分别换成GradCAMPlusPlus、ScoreCAM、AblationCAM、XGradCAM、EigenCAM以及FullGrad即可,简单粗暴。
三、写在后面
除了Transformer,pytorch_grad_cam库也可以用在之前提到的CNN的模型上,大家可自行探索哈。
四、数据
链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf
提取码:x3jf
相关文章:
第58步 深度学习图像识别:Transformer可视化(Pytorch)
一、写在前面 (1)pytorch_grad_cam库 这一期补上基于基于Transformer框架可视化的教程和代码,使用的是pytorch_grad_cam库,以Bottleneck Transformer模型为例。 (2)算法分类 pytorch_grad_cam库中包含的…...
angular实现全局组件
之前我们实现全局组件的第一种方式。我们是在定义了组件的时候通过在declares:[component],然后exports出该组件。最后在页面中每次导入该组件,而这次我们将采用另一种方式来实现 1 新建公用组件: navbreadcrumbnavbreadcrumb.component.htmlnavbreadc…...
Spring编程模型(范式)
面向对象编程 契约接口:Aware aware:意识到的 契约接口(Aware)是Spring框架中的一个特性,它允许Bean对象意识到它们所在的环境并与之进行交互,用于提供特定的功能或信息给Bean对象。这些接口通常作为回调接口,在Bean初始化过程…...
Golang GORM 单表删除
删除只有一个操作,delete。也是先找到再去删除。 可以删除单条记录,也可以删除多条记录。 var s Studentdb.Debug().Delete(&s, "age ?", 100)fmt.Println(s)[15.878ms] [rows:1] DELETE FROM student WHERE age 100var s Studentdb.De…...
Windows 下 MySQL 源码学习环境搭建步骤【建议收藏】
【建议收藏】Windows 下如何安装最新版 MySQL 源码学习的调试环境步骤。 作者:芬达 《芬达的数据库学习笔记》公众号作者,开源爱好者,擅长 MySQL、ansible。 本文来源:原创投稿 爱可生开源社区出品,原创内容未经授权不…...
redis总复习
springboot基于redisson实现看门狗锁:Springboot基于Redisson实现Redis分布式可重入锁【案例到源码分析】_springboot redission lock_AP0906424的博客-CSDN博客 springboot基于redis实现设置缓存和过期时间的代码?包括key的设计 https://mbd.baidu.com/ug_share…...
[LeetCode - Python]844. 比较;含退格的字符串(Easy);415. 字符串相加(Easy)
1.题目 844. 比较含退格的字符串(Easy) 1.代码: class Solution:def backspaceCompare(self, s: str, t: str) -> bool:# 暴力法s list(s)t list(t)M 0N 0for i in range(len(s)):i -M if s[i] # :if i > 0 :s.pop(i)s.pop(i-…...
机器学习深度学习——NLP实战(自然语言推断——注意力机制实现)
👨🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——NLP实战(自然语言推断——数据集) 📚订阅专栏:机器学习&…...
mac垃圾清理软件有哪些
随着使用时间的增加,mac系统会产生一些垃圾文件,影响系统的性能和稳定性。为了保持mac系统的高效,用户需要定期使用mac垃圾清理软件来清理系统缓存、日志、语言包等无用文件。CleanMyMac是一款功能强大的mac垃圾清理软件,它可以帮…...
8.18 校招 内推 面经
绿泡泡: neituijunsir 交流裙,内推/实习/校招汇总表格 1、校招 | 小米集团2024届全球校园招聘正式启动(内推) 校招 | 小米集团2024届全球校园招聘正式启动(内推) 2、2023校招总结--软件测试岗位 - 2 2…...
docker的web管理平台docker.ui
docker.ui安装 docker run --name docker.ui \ -p 8999:8999 \ --restartalways \ -v /var/run/docker.sock:/var/run/docker.sock \ -d joinsunsoft/docker.ui参数说明: docker run:启动container–name:容器命名–restartalwaysÿ…...
20230822 Windows上使用find_package引入OpenCV报错
报错信息 打开Cmake项目时,find_package 报错: Found OpenCV Windows Pack but it has no binaries compatible with yourconfiguration.You should manually point CMake variable OpenCV_DIR to your build of OpenCVlibrary.原因 大概率原项目是在 …...
MySQL下载安装配置
天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…...
3D WEB轻量化引擎HOOPS产品助力NAPA打造船舶设计软件平台
NAPA(Naval Architectural PAckage,船舶建筑包),来自芬兰的船舶设计软件供应商,致力于提供世界领先的船舶设计、安全及运营的解决方案和数据分析服务。NAPA拥有超过30年的船舶设计经验,年营业额超过2560万欧…...
lesson9: C++多线程
1.线程库 1.1 thread类的简单介绍 C11 中引入了对 线程的支持 了,使得 C 在 并行编程时 不需要依赖第三方库 而且在原子操作中还引入了 原子类 的概念。要使用标准库中的线程,必须包含 < thread > 头文件 函数名 功能 thread() 构造一个线程对象…...
安卓修改SwitchCompat色值
SwitchCompat控件色值跟系统设置的主题有关,但是主题效果不是能轻易就能改的,因为涉及到整个APP的样式。网上方案基本都是通过修改style文件来改变色值,经过多次尝试修改最终觉得单独修改控件色值比较好。 一、控件属性 //修改开关色值就是最…...
pytorch内存泄漏
问题描述: 内存泄漏积累过多最终会导致内存溢出,当内存占用过大,进程会被killed掉。 解决过程: 在代码的运行阶段输出内存占用量,观察在哪一块存在内存剧烈增加或者显存异常变化的情况。但是在这个过程中要分级确认…...
20230821-字符串相乘-给树命名(unordered_map)
字符串相乘 有两个非负整数字符串num1,num2,计算num1和num2所表达整数的乘积,结果以字符串形式存储。注意:不能通过强制转换方法解题。 示例1: 输入: "4", "3" 输出: "12" …...
[Go版]算法通关村第十二关黄金——字符串冲刺题
目录 题目:最长公共前缀解法1:纵向对比-循环内套循环写法复杂度:时间复杂度 O ( n ∗ m ) O(n*m) O(n∗m)、空间复杂度 O ( 1 ) O(1) O(1)Go代码 解法2:横向对比-两两对比(类似合并K个数组、合并K个链表)复…...
neovim为工作区添加本地clangd配置
1 背景 尝试使用neovim开发stm32,使用clangd作为LSP提供代码补全等功能。 2 思路 使用stm32cubeMX生成一个基于makefile的stm32工程。 使用bear或compiledb基于makefile生成compile_commands.json文件。 为clangd配置--query-driver选项,使其使用arm…...
信号处理--基于EEG脑电信号的眼睛状态的分析
本实验为生物信息学专题设计小项目。项目目的是通过提供的14导联EEG 脑电信号,实现对于人体睁眼和闭眼两个状态的数据分类分析。每个脑电信号的时长大约为117秒。 目录 加载相关的库函数 读取脑电信号数据并查看数据的属性 绘制脑电多通道连接矩阵 绘制两类数据…...
Redis高可用:主从复制详解
目录 1.什么是主从复制? 2.优势 3.主从复制的原理 4.全量复制和增量复制 4.1 全量复制 4.2 增量复制 5.相关问题总结 5.1 当主服务器不进行持久化时复制的安全性 5.2 为什么主从全量复制使用RDB而不使用AOF? 5.3 为什么还有无磁盘复制模式ÿ…...
[Flutter]有的时候调用setState(() {})报错?
先看FlutterSDK的原生类State中有一个变量mounted。 abstract class State<T extends StatefulWidget> with Diagnosticable {/// mounted的作用是,此State对象当前是否在树中。/// 在创建State对象之后,在调用initState之前,框架通过…...
利用屏幕水印学习英语单词,无打扰英语单词学习
1、利用屏幕水印学习英语单词,不影响任何鼠标键盘操作,不影响工作 2、利用系统热键快速隐藏(ALT1键 隐藏与显示) 3、日积月累单词会有进步 4、软件下载地址: 免安装,代码未加密,安全的屏幕水印学习英语…...
开学必备物品清单!这几款优先考虑!
马上就要开学了,同学们也要准备一系列开学用品,方便我们的学习生活,那有哪些数码物品可以在开学前准备的呢,接下来给大家安利几款很不错很实用的数码好物! 推荐一:南卡00压开放式蓝牙耳机 南卡00压开放式…...
聊聊调制解调器
目录 1.什么是调制解调器 2.调制解调器的工作原理 3.调制解调器的作用 4.调制解调器未来发展 1.什么是调制解调器 调制解调器(Modem)是一种用于在数字设备和模拟设备之间进行数据传输的设备。调制解调器将数字数据转换为模拟信号进行传输,…...
Go语言入门指南:基础语法和常用特性(下)
上一节,我们了解Go语言特性以及第一个Go语言程序——Hello World,这一节就让我们更深入的了解一下Go语言的**基础语法**吧! 一、行分隔符 在 Go 程序中,一行代表一个语句结束。每个语句不需要像 C 家族中的其它语言一样以分号 ;…...
【MFC常用问题记录】
MFC 记录 MFC的edit control控件显示1.控件添加变量M_edit后:2.控件ID为IDC_EDIT1: 线程函数使用 MFC的edit control控件显示 1.控件添加变量M_edit后: CString str; int x 10; str.Format(_T("%d"),x); M_edit.SetWindowText(str)2.控件ID…...
ThreadLocal内存泄漏问题
引子: 内存泄漏:是指本应该被GC回收的无用对象没有被回收,导致内存空间的浪费,当内存泄露严重时会导致内存溢出。Java内存泄露的根本原因是:长生命周期的对象持有短生命周期对象的引用,尽管短生命周期对象已…...
微服务基础概念【内含图解】
目录 拓展补充: 单体架构 分布式架构 面向服务的体系结构 云原生 微服务架构 什么是微服务? 微服务定义 拓展补充: 单体架构 单体架构:将业务的所有功能集中在一个项目中开发,最终打成一个包部署 优点&#x…...
大连网站制作 姚喜运/杭州优化公司在线留言
凌晨四点的东方是什么样子的,这个世界上恐怕只有科比和加班的程序员知道。 如果不是隔壁的高楼挡住了我的视线,我想东方的鱼肚白一定被我尽收眼底。想想都很多年没有去亲眼看日出了,但j是这却是这近一年第四次通宿加班了,也是游戏…...
建站行业消失了吗/百度地图导航2022最新版
touch事件一直是Android学习者一个头疼的问题,为了更好的学习事件的传递,我们来探索一下 touch事件有如下几种: Activity中包括: dispatchTouchEvent onTouchEvent ViewGroup有: dispatchTouchEvent onTouchEvent onIn…...
网络销售怎么做网站/网站排名优化服务
https://segmentfault.com/a/11900000093296191 参考二: https://www.jianshu.com/p/c51ffebeceed 流程图 转载于:https://www.cnblogs.com/wanqingcui/p/9814091.html...
专做女装拿货的网站/关键词搜索排名公司
传纸条 Time Limit: 1000ms Memory limit: 65536K 有疑问?点这里^_^ 题目描述 传纸条是一种在课堂上传递信息的老方法,虽然现在手机短信和QQ聊天越来越普及,但是手写的信息会让人感到一种亲切感。对许多学生而言&…...
wordpress url改变/西安网站定制开发
课程简介 第13天,怎样在SharePoint 2013中审批、拒绝列表项。 视频 SharePoint 2013 交流群 41032413 转载于:https://www.cnblogs.com/OceanEyes/p/day-13-sharepoint-approve-and-reject.html...
中国最大的库存尾货清货平台/seo关键词排名优化品牌
近来公司redmine服务器表现很糟糕,在16核,64GRAM的机器上,压测结果竟然只有每秒5~7个请求,部分页面一个都出不来。 以下是我对Redmine性能优化方案: redmine服务器性能问题排查与优化建议: 以下建议的方案是基于redmi…...