当前位置: 首页 > news >正文

【Pytorch】Visualization of Feature Maps(4)——Saliency Maps

在这里插入图片描述

学习参考来自

  • Saliency Maps的原理与简单实现(使用Pytorch实现)
  • https://github.com/wmn7/ML_Practice/tree/master/2019_07_08/Saliency%20Maps

Saliency Maps 原理

《Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps》(arXiv-2013)

在这里插入图片描述

A saliency map tells us the degree to which each pixel in the image affects the classification score for that image.
To compute it, we compute the gradient of the unnormalized score corresponding to the correct class (which is a scalar)
with respect to the pixels of the image. If the image has shape (3, H, W) then this gradient will also have shape (3, H, W);
for each pixel in the image, this gradient tells us the amount by which the classification score will change if the pixel
changes by a small amount. To compute the saliency map, we take the absolute value of this gradient, then take the maximum value over the 3 input channels; the final saliency map thus has shape (H, W) and all entries are non-negative.

Saliency Maps相当于是计算图像的每一个pixel是如何影响一个分类器的, 或者说分类器对图像中每一个pixel哪些认为是重要的.

会计算图像每一个像素点的梯度。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);对于图像中的每个像素点,
这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。

计算 saliency map 的时候,需要计算出梯度的绝对值,然后再取三个颜色通道的最大值;

因此最后的 saliency map的形状是(H, W)为一个通道的灰度图。


直接来代码,先载入些数据,用的是 cs231n 作业里面的 imagenet_val_25.npz,含有 imagenet 数据中验证集的 25 张图片

import torch
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import ImageSQUEEZENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
SQUEEZENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)def load_imagenet_val(num=None):"""Load a handful of validation images from ImageNet.Inputs:- num: Number of images to load (max of 25)Returns:- X: numpy array with shape [num, 224, 224, 3]- y: numpy array of integer image labels, shape [num]- class_names: dict mapping integer label to class name"""imagenet_fn = 'imagenet_val_25.npz'if not os.path.isfile(imagenet_fn):print('file %s not found' % imagenet_fn)print('Run the following:')print('cd cs231n/datasets')print('bash get_imagenet_val.sh')assert False, 'Need to download imagenet_val_25.npz'f = np.load(imagenet_fn, allow_pickle=True)X = f['X']  # (25, 224, 224, 3)y = f['y']  # (25, )class_names = f['label_map'].item()  # 999if num is not None:X = X[:num]y = y[:num]return X, y, class_names

图像的前处理,resize,变成向量,减均值除以方差

# 辅助函数
def preprocess(img, size=224):transform = T.Compose([T.Resize(size),T.ToTensor(),T.Normalize(mean=SQUEEZENET_MEAN.tolist(),std=SQUEEZENET_STD.tolist()),T.Lambda(lambda x: x[None]),])return transform(img)

在这里插入图片描述

数据集和实验的模型

链接:https://pan.baidu.com/s/1vb2Y0IiHdH_Fb9wibTta4Q?pwd=zuvw
提取码:zuvw


核心代码,计算 saliency maps

def compute_saliency_maps(X, y, model):"""X表示图片, y表示分类结果, model表示使用的分类模型Input : - X : Input images : Tensor of shape (N, 3, H, W)- y : Label for X : LongTensor of shape (N,)- model : A pretrained CNN that will be used to computer the saliency mapReturn :- saliency : A Tensor of shape (N, H, W) giving the saliency maps for the input images"""# 确保model是test模式model.eval()# 确保X是需要gradientX.requires_grad_() # 仅开启了输入图片的梯度saliency = Nonelogits = model.forward(X)  # torch.Size([5, 1000]), 前向获取 logitslogits = logits.gather(1, y.view(-1, 1)).squeeze()  # torch.Size([5]) 得到正确分类 logits (5张图片标签相应类别的 logits)logits.backward(torch.FloatTensor([1., 1., 1., 1., 1.]))  # 只计算正确分类部分的loss(正确类别梯度为 1 回传)saliency = abs(X.grad.data)  # 返回X的梯度绝对值大小, torch.Size([5, 3, 224, 224])saliency, _ = torch.max(saliency, dim=1)  # torch.Size([5, 224, 224]),取 rgb 3通道的最大值return saliency.squeeze()

显示 saliency maps

def show_saliency_maps(X, y):# Convert X and y from numpy arrays to Torch TensorsX_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0) # torch.Size([5, 3, 224, 224])y_tensor = torch.LongTensor(y)# Compute saliency maps for images in Xsaliency = compute_saliency_maps(X_tensor, y_tensor, model)# Convert the saliency map from Torch Tensor to numpy array and show images# and saliency maps together.saliency = saliency.numpy()N = X.shape[0]  # 5for i in range(N):plt.subplot(2, N, i + 1)plt.imshow(X[i])plt.axis('off')plt.title(class_names[y[i]])plt.subplot(2, N, N + i + 1)plt.imshow(saliency[i], cmap=plt.cm.hot)plt.axis('off')plt.gcf().set_size_inches(12, 5)plt.show()

下面开始调用,首先载入模型,使其梯度冻结,仅打开输入图片的梯度,这样反向传播的时候会更新图片,得到我们想要的 saliency maps

# Download and load the pretrained SqueezeNet model.
model = torchvision.models.squeezenet1_1(pretrained=True)# We don't want to train the model, so tell PyTorch not to compute gradients
# with respect to model parameters.
for param in model.parameters():param.requires_grad = False

加载一些图片看看,25 张中抽出来 5 张

X, y, class_names = load_imagenet_val(num=5)  # X: (5, 224, 224, 3) | y: (5,) | class_names: 999"show images"plt.figure(figsize=(12, 6))
for i in range(5):plt.subplot(1, 5, i + 1)plt.imshow(X[i])plt.title(class_names[y[i]])plt.axis('off')
plt.gcf().tight_layout()
plt.show()

显示图片
在这里插入图片描述
把五张图片的 saliency maps 画出来

show_saliency_maps(X, y)

我把 25 张都画出来了
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述


核心代码中涉及到了 gather 函数,下面来个简单的例子就明白了

# Example of using gather to select one entry from each row in PyTorch
# 用来返回matrix指定行某个位置的值
import torchdef gather_example():N, C = 4, 5s = torch.randn(N, C) # 随机生成 4 行 5 列的 tensory = torch.LongTensor([1, 2, 1, 3])print(s)print(y)print(torch.LongTensor(y).view(-1, 1))print(s.gather(1, y.view(-1, 1)).squeeze()) # 抽取每行相应的列数位置上的数值gather_example()"""
tensor([[ 0.8119,  0.2664, -1.4168, -0.1490, -0.0675],[ 0.5335,  0.6304, -0.7200, -0.0974, -0.9934],[-0.8305,  0.5189,  0.7359,  1.5875,  0.0505],[ 0.4335, -1.1389, -0.7771,  0.5779,  0.3515]])
tensor([1, 2, 1, 3])
tensor([[1],[2],[1],[3]])
tensor([ 0.2664, -0.7200,  0.5189,  0.5779])
"""

相关文章:

【Pytorch】Visualization of Feature Maps(4)——Saliency Maps

学习参考来自 Saliency Maps的原理与简单实现(使用Pytorch实现)https://github.com/wmn7/ML_Practice/tree/master/2019_07_08/Saliency%20Maps Saliency Maps 原理 《Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps》&…...

java第三十课

电商项目(前台): 登录接口 注册接口后台: 注册审核:建一个线程类 注意程序中的一个问题。 这里是 5 条记录,2 条记录显示应该是 3 页,实际操作过程 有审核机制,出现了数据记录动态变…...

Scala--2

package scala02object Scala07_typeCast {def main(args: Array[String]): Unit {// TODO 隐式转换// 自动转换val b: Byte 10var i: Int b 10val l: Long b 10 100Lval fl: Float b 10 100L 10.5fval d: Double b 10 100L 10.5f 20.00println(d.getClass…...

【SQL SERVER】定时任务

oracle是定时JOB,sqlserver是创建作业,通过sqlserver代理实现 先看SQL SERVER代理得服务有没有开 选择计算机右键——>管理——>服务与应用程序——>服务——>SQL server 代理 然后把SQL server 代理(MSSQLSERVER)启…...

MyBatis-Plus学习笔记(无脑cv即可)

1.MyBatis-Plus 1.1特性 无侵入:只做增强不做改变,引入它不会对现有工程产生影响,如丝般顺滑损耗小:启动即会自动注入基本 CURD,性能基本无损耗,直接面向对象操作强大的 CRUD 操作:内置通用 M…...

【VUE】watch 监听失效

如果你遇见了这个问题,那么尝试在 watch 函数中设置 { deep: true } 选项。这告诉 Vue 监听对象或数组内部的变化,就像下面这样: watch(()>chatStore.dataSources,(oldValue, newValue)>{// 监听执行逻辑 }, { deep: true })嗯&#x…...

python的异常处理批量执行网络设备的巡检命令

前言 在网络设备数量超过千台甚至上万台的大型企业网中,难免会遇到某些设备的管理IP地址不通,SSH连接失败的情况,设备数量越多,这种情况发生的概率越高。 这个时候如果你想用python批量配置所有的设备,就一定要注意这…...

react native 环境准备

一、必备安装 1、安装node 注意 Node 的版本应大于等于 16,安装完 Node 后建议设置 npm 镜像(淘宝源)以加速后面的过程(或使用科学上网工具)。 node下载地址:Download | Node.js设置淘宝源 npm config s…...

PGSQL(PostgreSQL)数据库安装教程

安装包下载 下载地址 下载后点击exe安装包 设置的data存储路径 设置密码 设置端口 安装完毕,配置PGSQL的ip远程连接,pg_hba.conf,postgresql.conf,需要更改这两个文件 pg_hba.conf 最后增加一行 host all all …...

识别和修复网站上损坏链接的最佳实践

如果您有一个网站,我们知道您花了很多时间在它上面,以使其成为最好的资源。如果你的链接不起作用,你的努力可能是徒劳的。您网站上的断开链接可能会以两种方式损害您的业务: 它们对企业来说是可怕的,因为当消费者点击…...

使用Navicat连接MySQL出现的一些错误

目录 一、错误一:防火墙未关闭 二、错误二:安全组问题 三、错误三:MySQL密码的加密方式 四、错误四:修改my.cnf配置文件 一、错误一:防火墙未关闭 #查看防火墙状态 firewall-cmd --state#关闭防…...

4G基站BBU、RRU、核心网设备

目录 前言 基站 核心网 信号传输 前言 移动运营商在建设4G基站的时候,除了建设一座铁塔之外,更重要的是建设搭载铁塔之上的移动通信设备,这篇博客主要介绍BBU,RRU以及机房的核心网等设备。 基站 一个基站有BBU,…...

iphone/安卓手机如何使用burp抓包

iphone 1. 电脑 ipconfig /all 获取电脑网卡ip: 192.168.31.10 2. 电脑burp上面打开设置,proxy,增加一条 192.168.31.10:8080 3. 4. 手机进入设置 -> Wi-Fi -> 找到HTTP代理选项,选择手动,192.168.31.10:8080 …...

springboot云HIS医院信息综合管理平台源码

满足基层医院机构各类业务需要的健康云HIS系统。该系统能帮助基层医院机构完成日常各类业务,提供病患挂号支持、病患问诊、电子病历、开药发药、会员管理、统计查询、医生站和护士站等一系列常规功能,能与公卫、PACS等各类外部系统融合,实现多…...

【视觉SLAM十四讲学习笔记】第三讲——四元数

专栏系列文章如下: 【视觉SLAM十四讲学习笔记】第一讲——SLAM介绍 【视觉SLAM十四讲学习笔记】第二讲——初识SLAM 【视觉SLAM十四讲学习笔记】第三讲——旋转矩阵 【视觉SLAM十四讲学习笔记】第三讲——旋转向量和欧拉角 本章将介绍视觉SLAM的基本问题之一&#x…...

Linux系统之部署Plik临时文件上传系统

Linux系统之部署Plik临时文件上传系统 一、Plik介绍1.1 Plik简介1.2 Plik特点 二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍 三、检查本地环境3.1 检查本地操作系统版本3.2 检查系统内核版本 四、下载Plik软件包4.1 创建下载目录4.2 下载Plik软件包4.3 查看下载的Plik软件…...

【EI征稿中#先投稿,先送审#】第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024)

第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024) 2024 3rd International Conference on Cyber Security, Artificial Intelligence and Digital Economy 第二届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2023&…...

『亚马逊云科技产品测评』活动征文|基于亚马逊云EC2搭建OA系统

授权声明:本篇文章授权活动官方亚马逊云科技文章转发、改写权,包括不限于在 Developer Centre, 知乎,自媒体平台,第三方开发者媒体等亚马逊云科技官方渠道 亚马逊EC2云服务器(Elastic Compute Cloud)是亚马…...

Mysql更新varchar存储的Josn数据

Mysql更新varchar存储的Josn数据 记录一次mysql操作varchar格式存储的json字符串数据 1、检查版本 -- 版本5.7以上才可以能执行json操作 select version(); 2、创建测试数据 -- 创建测试表及测试数据 CREATE TABLE test_json_table AS SELECT UUID(), {"test1": …...

JSON.stringify与JSON.parse详解与实践

目录 JSON.stringify 简介 主要用途: API 实践1: 实践2: JSON.parse 简介 API 实践1 实践2 JSON.stringify 简介 用于把JavaScript对象、数组、值、布尔值等序列化成字符串形式。 主要用途: 得到的数据通常有以下主…...

vue 基础

双向绑定的原理 双向绑定是一种数据绑定技术,它能够实现数据的自动同步更新,即当用户修改了数据时,界面也会随之自动更新,反之亦然。其原理如下: 数据模型:双向绑定的第一步是建立一个数据模型&#xff0c…...

使用axios下载后端接口返回的文件流格式文件

在实际开发中,我们经常会遇到下载文件的需求,一般情况下接口最好的处理方式为上传到文件对象存储服务器,然后给前端返回一个下载文件的URL,前端直接打开链接下载就可以了,但…在下载数据量大且参数复杂的情况下&#x…...

在macOS上使用Homebrew安装PHP的完整指南

安装最新版本的PHP 步骤1: 安装Homebrew 在安装最新版本的PHP之前,确保你的macOS系统上已经安装了Homebrew。如果尚未安装,打开终端并运行以下命令: /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install…...

图片处理OpenCV IMDecode模式说明【生产问题处理】

OpenCV IMDecode模式说明【生产问题处理】 1 前言 今天售后同事反馈说客户使用我们的图片处理,将PNG图片处理为JPG图片之后,变为了白板。 我们图片处理使用的是openCV来进行处理 2 分析 2.1 图片是否损坏:非标准PNG头部 于是,马…...

吹响AI技术应用的号角

毫无疑问,各企业正围绕各种技术展开一场持续不断的角逐,力争率先取得领先且具创新性的技术进步,AI技术也不例外。疫情期间,全球各地企业的员工纷纷转向居家办公。因此,为轻松实现这一转型并建立起远程办公的新常态&…...

C //例10.1 从键盘输入一些字符,逐个把它们送到磁盘上去,直到用户输入一个“#”为止。

C程序设计 (第四版) 谭浩强 例10.1 例10.1 从键盘输入一些字符,逐个把它们送到磁盘上去,直到用户输入一个“#”为止。 IDE工具:VS2010 Note: 使用不同的IDE工具可能有部分差异。 代码块 方法:使用指针&…...

ARM预取侧信道(Prefetcher Side Channels)攻击与防御

目录 一、预取侧信道简介 1.1 背景:预取分类 二、Arm核会受到影响吗? 2.1 先进的预取器...

数据结构 | 二叉树的各种遍历

数据结构 | 二叉树的各种遍历 文章目录 数据结构 | 二叉树的各种遍历创建节点 && 创建树二叉树的前中后序遍历二叉树节点个数二叉树叶子节点个数二叉树第k层节点个数二叉树查找值为x的节点二叉树求树的高度二叉树的层序遍历判断二叉树是否是完全二叉树 我们本章来实现二…...

Python-赋值运算符(详解)

表示赋值 左侧为变量,右边为值 a b 10#先把10赋值给b,再把b赋值给a 相当于a 10 b 10 链式赋值,但是不推荐,一般一行一个语句,提高可读性,良好的代码风格 多元赋值: a , b 10,20 #python语…...

算法工程师面试八股(搜广推方向)

文章目录 机器学习线性和逻辑回归模型逻辑回归二分类和多分类的损失函数二分类为什么用交叉熵损失而不用MSE损失?偏差与方差Layer Normalization 和 Batch NormalizationSVM数据不均衡特征选择排序模型树模型进行特征工程的原因GBDTLR和GBDTRF和GBDTXGBoost二阶泰勒…...