第58步 深度学习图像识别:Transformer可视化(Pytorch)
一、写在前面
(1)pytorch_grad_cam库
这一期补上基于基于Transformer框架可视化的教程和代码,使用的是pytorch_grad_cam库,以Bottleneck Transformer模型为例。
(2)算法分类
pytorch_grad_cam库中包含的主要方法有以下几种:
GradCAM: 这是最基本的方法。GradCAM(Gradient-weighted Class Activation Mapping)通过取网络最后一个卷积层的特征图,然后对这些特征图进行加权求和,得到类别激活图。加权的系数是网络最后一个卷积层特征图对应类别的梯度的全局平均池化值。
GradCAMPlusPlus: 这是在GradCAM的基础上的改进。GradCAM++不仅计算了类别相对于特征图的梯度,还计算了二阶和三阶导数。这使得GradCAM++在某些情况下可以获得更细粒度的解释。
ScoreCAM: ScoreCAM采用了不同的策略。它对于每个特征图都生成一个类似的激活图,并将所有这些激活图加权求和。权重是每个特征图对应的类别分数。
AblationCAM: AblationCAM是基于Ablation-based的方法。它首先对每个特征图进行遮挡(或移除),然后看类别得分如何改变。这些改变被用来生成类别激活图。
XGradCAM: 这是GradCAM的另一个扩展。XGradCAM考虑了激活和梯度之间的空间关系,以生成更详细的类别激活图。
EigenCAM: 它基于主成分分析 (PCA) 的方法,利用协方差矩阵的特征向量和特征值来表示激活图。
FullGrad: FullGrad是一个对输入,权重和偏差的特征重要性进行全局分解的方法。
以上方法都在解释深度学习模型的决策,可以帮助理解模型关注的区域和特征。在选择使用哪种方法时,可以根据需求和实验效果进行选择。
二、Transformer可视化实战
继续使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。
(a)Bottleneck Transformer建模
######################################导入包###################################
# 导入必要的包
import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as npwarnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")################################导入数据集#####################################
import torch
from torchvision import datasets, transforms
import os# 数据集路径
data_dir = "./MTB"# 图像的大小
img_height = 256
img_width = 256# 数据预处理
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(img_height),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((img_height, img_width)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
full_dataset = datasets.ImageFolder(data_dir)# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.7 * full_size) # 假设训练集占80%
val_size = full_size - train_size # 验证集的大小# 随机分割数据集
torch.manual_seed(0) # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])# 将数据增强应用到训练集
train_dataset.dataset.transform = data_transforms['train']# 创建数据加载器
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes###############################定义模型################################
# 导入必要的库
import torch.nn as nn
import timm# 定义Bottleneck Transformer模型
model = timm.create_model('botnet26t_256', pretrained=True) # 你可以选择适合你需求的BotNet版本
num_ftrs = model.feature_info[-1]['num_chs']# 根据分类任务修改最后一层
model.head.fc = nn.Linear(num_ftrs, len(class_names))# 将模型移至指定设备
model = model.to(device)# 打印模型摘要
print(model)#############################编译模型#########################################
# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.Adam(model.parameters())# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 开始训练模型
num_epochs = 2# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train() # 设置模型为训练模式else:model.eval() # 设置模型为评估模式running_loss = 0.0running_corrects = 0# 遍历数据for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 零参数梯度optimizer.zero_grad()# 前向with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 只在训练模式下进行反向和优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()# 记录每个epoch的loss和accuracyif phase == 'train':train_loss_history.append(epoch_loss)train_acc_history.append(epoch_acc)else:val_loss_history.append(epoch_loss)val_acc_history.append(epoch_acc)print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))print()# 保存模型
torch.save(model.state_dict(), 'botnet_dit_model.pth')
(b)使用GradCAM可视化
在跑之前,得先安装git;然后用git安装pytorch_grad_cam:
安装git容易,无脑输入:
conda install git
安装pytorch_grad_cam也不难:
git clone https://github.com/jacobgil/pytorch-grad-cam.git
cd pytorch-grad-cam
pip install .
然后码代码:
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import timm# 代码1中的函数
def myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens,size))if titles == False:titles="0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor,heatmap=False,shape=(256,256)):np_arr=tensor.detach().numpy()#[0]#对数据进行归一化if np_arr.max()>1 or np_arr.min()<0:np_arr=np_arr-np_arr.min()np_arr=np_arr/np_arr.max()#np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0]==1:# 如果是灰度图像,复制三个通道以创建一个RGB图像np_arr=np.concatenate([np_arr,np_arr,np_arr],axis=0)np_arr=np_arr.transpose((1,2,0))return np_arr# 加载模型
model = timm.create_model('botnet26t_256', pretrained=False)# 更改全连接层以匹配你的类别数
num_ftrs = model.head.fc.in_features
model.head.fc = nn.Linear(num_ftrs, 2) # 假设你的类别数为2model.load_state_dict(torch.load('botnet_dit_model.pth', map_location=device))# 模型转移到相应设备
model = model.to(device)# 你的图像路径
image_path = './MTB/Tuberculosis/Tuberculosis-203.png'# 加载图像
image = Image.open(image_path).convert("RGB")# 使用代码1中定义的图像转换
input_image = data_transforms['val'](image).unsqueeze(0).to(device)# 使用GradCAM
target_layer = model.stages[2][0].conv3_1x1.bn.drop
with GradCAM(model=model, target_layers=[target_layer], use_cuda=torch.cuda.is_available()) as cam:target = [ClassifierOutputTarget(1)] # 修改为你的目标类别grayscale_cam = cam(input_tensor=input_image, targets=target)#将热力图结果与原图进行融合rgb_img=tensor2img(input_image.cpu().squeeze())visualization = show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
myimshows([rgb_img, grayscale_cam[0], visualization],["image","cam","image + cam"])
结果输出如下:
红色区域就是模型认为的“可疑区域”,也就是说模型根据这些区域判断它是Tuberculosis的主要依据。
几个注意事项:
(a)问:代码:‘target = [ClassifierOutputTarget(0)] # 修改为你的目标类别’,这个怎么解释?此外,0和1分别代表什么呢?
答:第一小问:一般来说,ClassifierOutputTarget(0)中的0代表的是你希望将注意力图(CAM)生成针对的类别标签。例如,如果你的两个类别是猫和狗,且在训练数据集中猫的标签是0,狗的标签是1,那么ClassifierOutputTarget(0)将生成猫的注意力图,而ClassifierOutputTarget(1)将生成狗的注意力图。
第二小问:在 PyTorch 中,使用 ImageFolder 函数或类似的数据加载器加载数据时,类别名称列表(class_names)的顺序将决定了类别标签的分配。这意味着类别名称列表的索引将作为类别的标签。在我们的例子中,class_names = ['Normal', 'Tuberculosis'],"Normal" 的索引是 0,所以它的标签是 0;"Tuberculosis" 的索引是 1,所以它的标签是 1。所以ClassifierOutputTarget(0) 将生成"Normal"类别的注意力图,ClassifierOutputTarget(1) 将生成"Tuberculosis"类别的注意力图。
(b)问:代码:‘target_layer = model.stages[2][0].conv3_1x1.conv’,如何选择输出的层?怎么知道模型中有哪些层?
答:第一小问:一般来说,卷积层或者重复结构的最后一层(如 ResNet 中的每个残差块的最后一层)是可行的目标层,因为这些层能保留空间信息,而全连接层则不行,因为它们不再保留空间信息。
第二小问:通过下面代码打印出模型中所有层次的名称:
#打印出模型中所有层次的名称
for name, module in model.named_modules():
print(name)
输出如下:
或者打印出模型的顶层子模块:
#打印模型的顶层子模块
for name, module in model.named_children():print(name)
输出就四个:
stem
stages
final_conv
head
接下来,展示几个层的写法,大家自行体会:
stem.conv2.conv :target_layer = model.stem.conv2.conv
stages.3.1.conv1_1x1:target_layer = model.stages[3][1].conv1_1x1
final_conv:target_layer = model.final_conv
应该找到规律了吧,不详细解释了。每一层输出是不一样的,例如上面三层输出依次如下:
(c)问:如何改用其他7种方法来替代GradCAM?
答:很简单,来到这个代码段:
with GradCAM(model=model, target_layers=[target_layer], use_cuda=torch.cuda.is_available()) as cam:target = [ClassifierOutputTarget(0)] # 修改为你的目标类别grayscale_cam = cam(input_tensor=input_image, targets=target)#将热力图结果与原图进行融合rgb_img=tensor2img(input_image.cpu().squeeze())visualization = show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
myimshows([rgb_img, grayscale_cam[0], visualization],["image","cam","image + cam"])
只需要把GradCAM分别换成GradCAMPlusPlus、ScoreCAM、AblationCAM、XGradCAM、EigenCAM以及FullGrad即可,简单粗暴。
三、写在后面
除了Transformer,pytorch_grad_cam库也可以用在之前提到的CNN的模型上,大家可自行探索哈。
四、数据
链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf
提取码:x3jf
相关文章:

第58步 深度学习图像识别:Transformer可视化(Pytorch)
一、写在前面 (1)pytorch_grad_cam库 这一期补上基于基于Transformer框架可视化的教程和代码,使用的是pytorch_grad_cam库,以Bottleneck Transformer模型为例。 (2)算法分类 pytorch_grad_cam库中包含的…...
angular实现全局组件
之前我们实现全局组件的第一种方式。我们是在定义了组件的时候通过在declares:[component],然后exports出该组件。最后在页面中每次导入该组件,而这次我们将采用另一种方式来实现 1 新建公用组件: navbreadcrumbnavbreadcrumb.component.htmlnavbreadc…...
Spring编程模型(范式)
面向对象编程 契约接口:Aware aware:意识到的 契约接口(Aware)是Spring框架中的一个特性,它允许Bean对象意识到它们所在的环境并与之进行交互,用于提供特定的功能或信息给Bean对象。这些接口通常作为回调接口,在Bean初始化过程…...

Golang GORM 单表删除
删除只有一个操作,delete。也是先找到再去删除。 可以删除单条记录,也可以删除多条记录。 var s Studentdb.Debug().Delete(&s, "age ?", 100)fmt.Println(s)[15.878ms] [rows:1] DELETE FROM student WHERE age 100var s Studentdb.De…...

Windows 下 MySQL 源码学习环境搭建步骤【建议收藏】
【建议收藏】Windows 下如何安装最新版 MySQL 源码学习的调试环境步骤。 作者:芬达 《芬达的数据库学习笔记》公众号作者,开源爱好者,擅长 MySQL、ansible。 本文来源:原创投稿 爱可生开源社区出品,原创内容未经授权不…...
redis总复习
springboot基于redisson实现看门狗锁:Springboot基于Redisson实现Redis分布式可重入锁【案例到源码分析】_springboot redission lock_AP0906424的博客-CSDN博客 springboot基于redis实现设置缓存和过期时间的代码?包括key的设计 https://mbd.baidu.com/ug_share…...

[LeetCode - Python]844. 比较;含退格的字符串(Easy);415. 字符串相加(Easy)
1.题目 844. 比较含退格的字符串(Easy) 1.代码: class Solution:def backspaceCompare(self, s: str, t: str) -> bool:# 暴力法s list(s)t list(t)M 0N 0for i in range(len(s)):i -M if s[i] # :if i > 0 :s.pop(i)s.pop(i-…...

机器学习深度学习——NLP实战(自然语言推断——注意力机制实现)
👨🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——NLP实战(自然语言推断——数据集) 📚订阅专栏:机器学习&…...

mac垃圾清理软件有哪些
随着使用时间的增加,mac系统会产生一些垃圾文件,影响系统的性能和稳定性。为了保持mac系统的高效,用户需要定期使用mac垃圾清理软件来清理系统缓存、日志、语言包等无用文件。CleanMyMac是一款功能强大的mac垃圾清理软件,它可以帮…...
8.18 校招 内推 面经
绿泡泡: neituijunsir 交流裙,内推/实习/校招汇总表格 1、校招 | 小米集团2024届全球校园招聘正式启动(内推) 校招 | 小米集团2024届全球校园招聘正式启动(内推) 2、2023校招总结--软件测试岗位 - 2 2…...
docker的web管理平台docker.ui
docker.ui安装 docker run --name docker.ui \ -p 8999:8999 \ --restartalways \ -v /var/run/docker.sock:/var/run/docker.sock \ -d joinsunsoft/docker.ui参数说明: docker run:启动container–name:容器命名–restartalwaysÿ…...

20230822 Windows上使用find_package引入OpenCV报错
报错信息 打开Cmake项目时,find_package 报错: Found OpenCV Windows Pack but it has no binaries compatible with yourconfiguration.You should manually point CMake variable OpenCV_DIR to your build of OpenCVlibrary.原因 大概率原项目是在 …...

MySQL下载安装配置
天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…...

3D WEB轻量化引擎HOOPS产品助力NAPA打造船舶设计软件平台
NAPA(Naval Architectural PAckage,船舶建筑包),来自芬兰的船舶设计软件供应商,致力于提供世界领先的船舶设计、安全及运营的解决方案和数据分析服务。NAPA拥有超过30年的船舶设计经验,年营业额超过2560万欧…...

lesson9: C++多线程
1.线程库 1.1 thread类的简单介绍 C11 中引入了对 线程的支持 了,使得 C 在 并行编程时 不需要依赖第三方库 而且在原子操作中还引入了 原子类 的概念。要使用标准库中的线程,必须包含 < thread > 头文件 函数名 功能 thread() 构造一个线程对象…...
安卓修改SwitchCompat色值
SwitchCompat控件色值跟系统设置的主题有关,但是主题效果不是能轻易就能改的,因为涉及到整个APP的样式。网上方案基本都是通过修改style文件来改变色值,经过多次尝试修改最终觉得单独修改控件色值比较好。 一、控件属性 //修改开关色值就是最…...

pytorch内存泄漏
问题描述: 内存泄漏积累过多最终会导致内存溢出,当内存占用过大,进程会被killed掉。 解决过程: 在代码的运行阶段输出内存占用量,观察在哪一块存在内存剧烈增加或者显存异常变化的情况。但是在这个过程中要分级确认…...
20230821-字符串相乘-给树命名(unordered_map)
字符串相乘 有两个非负整数字符串num1,num2,计算num1和num2所表达整数的乘积,结果以字符串形式存储。注意:不能通过强制转换方法解题。 示例1: 输入: "4", "3" 输出: "12" …...

[Go版]算法通关村第十二关黄金——字符串冲刺题
目录 题目:最长公共前缀解法1:纵向对比-循环内套循环写法复杂度:时间复杂度 O ( n ∗ m ) O(n*m) O(n∗m)、空间复杂度 O ( 1 ) O(1) O(1)Go代码 解法2:横向对比-两两对比(类似合并K个数组、合并K个链表)复…...
neovim为工作区添加本地clangd配置
1 背景 尝试使用neovim开发stm32,使用clangd作为LSP提供代码补全等功能。 2 思路 使用stm32cubeMX生成一个基于makefile的stm32工程。 使用bear或compiledb基于makefile生成compile_commands.json文件。 为clangd配置--query-driver选项,使其使用arm…...
OpenLayers 可视化之热力图
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 热力图(Heatmap)又叫热点图,是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...
DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径
目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

AI Agent与Agentic AI:原理、应用、挑战与未来展望
文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例:使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例:使用OpenAI GPT-3进…...
uni-app学习笔记二十二---使用vite.config.js全局导入常用依赖
在前面的练习中,每个页面需要使用ref,onShow等生命周期钩子函数时都需要像下面这样导入 import {onMounted, ref} from "vue" 如果不想每个页面都导入,需要使用node.js命令npm安装unplugin-auto-import npm install unplugin-au…...
A2A JS SDK 完整教程:快速入门指南
目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库ÿ…...
Java求职者面试指南:计算机基础与源码原理深度解析
Java求职者面试指南:计算机基础与源码原理深度解析 第一轮提问:基础概念问题 1. 请解释什么是进程和线程的区别? 面试官:进程是程序的一次执行过程,是系统进行资源分配和调度的基本单位;而线程是进程中的…...

[大语言模型]在个人电脑上部署ollama 并进行管理,最后配置AI程序开发助手.
ollama官网: 下载 https://ollama.com/ 安装 查看可以使用的模型 https://ollama.com/search 例如 https://ollama.com/library/deepseek-r1/tags # deepseek-r1:7bollama pull deepseek-r1:7b改token数量为409622 16384 ollama命令说明 ollama serve #:…...
Bean 作用域有哪些?如何答出技术深度?
导语: Spring 面试绕不开 Bean 的作用域问题,这是面试官考察候选人对 Spring 框架理解深度的常见方式。本文将围绕“Spring 中的 Bean 作用域”展开,结合典型面试题及实战场景,帮你厘清重点,打破模板式回答,…...

解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用
在工业制造领域,无损检测(NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统,以非接触式光学麦克风技术为核心,打破传统检测瓶颈,为半导体、航空航天、汽车制造等行业提供了高灵敏…...

Vue ③-生命周期 || 脚手架
生命周期 思考:什么时候可以发送初始化渲染请求?(越早越好) 什么时候可以开始操作dom?(至少dom得渲染出来) Vue生命周期: 一个Vue实例从 创建 到 销毁 的整个过程。 生命周期四个…...