当前位置: 首页 > news >正文

PyTorch实战:模型训练中的特征图可视化技巧

1.特征图可视化,这种方法是最简单,输入一张照片,然后把网络中间某层的输出的特征图按通道作为图片进行可视化展示即可。

2.特征图可视化代码如下:

def featuremap_visual(feature, out_dir=None,  # 特征图保存路径文件save_feature=True,  # 是否以图片形式保存特征图show_feature=True,  # 是否使用plt显示特征图feature_title=None,  # 特征图名字,默认以shape作为titlenum_ch=-1,  # 显示特征图前几个通道,-1 or None 都显示nrow=8,  # 每行显示多少个特征图通道padding=10,  # 特征图之间间隔多少像素值pad_value=1  # 特征图之间的间隔像素):import matplotlib.pylab as pltimport torchvisionimport os# feature = feature.detach().cpu()b, c, h, w = feature.shapefeature = feature[0]feature = feature.unsqueeze(1)if c > num_ch > 0:feature = feature[:num_ch]img = torchvision.utils.make_grid(feature, nrow=nrow, padding=padding, pad_value=pad_value)img = img.detach().cpu()img = img.numpy()images = img.transpose((1, 2, 0))# title = str(images.shape) if feature_title is None else str(feature_title)title = str('hwc-') + str(h) + '-' + str(w) + '-' + str(c) if feature_title is None else str(feature_title)plt.title(title)plt.imshow(images)if save_feature:# root=r'C:\Users\Administrator\Desktop\CODE_TJ\123'# plt.savefig(os.path.join(root,'1.jpg'))out_root = title + '.jpg' if out_dir == '' or out_dir is None else os.path.join(out_dir, title + '.jpg')plt.savefig(out_root)if show_feature:        plt.show()

3.结合resnet网络整体可视化(主要将其featuremap_visual函数插入forward中,即可),整体代码如下:

resnet网络结构在我博客:
残差网络ResNet(超详细代码解析) :你必须要知道backbone模块成员之一 - tangjunjun - 博客园

"""
@author: tangjun
@contact: 511026664@qq.com
@time: 2020/12/7 22:48
@desc: 残差ackbone改写,用于构建特征提取模块
"""import torch.nn as nn
import torch
from collections import OrderedDictdef Conv(in_planes, out_planes, **kwargs):"3x3 convolution with padding"padding = kwargs.get('padding', 1)bias = kwargs.get('bias', False)stride = kwargs.get('stride', 1)kernel_size = kwargs.get('kernel_size', 3)out = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)return outclass BasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = Conv(inplanes, planes, stride=stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = Conv(planes, planes)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None):super(Bottleneck, self).__init__()self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(planes * 4)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return outclass Resnet(nn.Module):arch_settings = {18: (BasicBlock, (2, 2, 2, 2)),34: (BasicBlock, (3, 4, 6, 3)),50: (Bottleneck, (3, 4, 6, 3)),101: (Bottleneck, (3, 4, 23, 3)),152: (Bottleneck, (3, 8, 36, 3))}def __init__(self,depth=50,in_channels=None,pretrained=None,frozen_stages=-1# num_classes=None):super(Resnet, self).__init__()self.inplanes = 64self.inchannels = in_channels if in_channels is not None else 3  # 输入通道# self.num_classes=num_classesself.block, layers = self.arch_settings[depth]self.frozen_stages = frozen_stagesself.conv1 = nn.Conv2d(self.inchannels, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(self.block, 64, layers[0], stride=1)self.layer2 = self._make_layer(self.block, 128, layers[1], stride=2)self.layer3 = self._make_layer(self.block, 256, layers[2], stride=2)self.layer4 = self._make_layer(self.block, 512, layers[3], stride=2)# self.avgpool = nn.AvgPool2d(7)# self.fc = nn.Linear(512 * self.block.expansion, self.num_classes)self._freeze_stages()  # 冻结函数if pretrained is not None:self.init_weights(pretrained=pretrained)def _freeze_stages(self):if self.frozen_stages >= 0:self.norm1.eval()for m in [self.conv1, self.norm1]:for param in m.parameters():param.requires_grad = Falsefor i in range(1, self.frozen_stages + 1):m = getattr(self, 'layer{}'.format(i))m.eval()for param in m.parameters():param.requires_grad = Falsedef init_weights(self, pretrained=None):if isinstance(pretrained, str):self.load_checkpoint(pretrained)elif pretrained is None:for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='relu')if hasattr(m, 'bias') and m.bias is not None:  # m包含该属性且m.bias非None # hasattr(对象,属性)表示对象是否包含该属性nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()def load_checkpoint(self, pretrained):checkpoint = torch.load(pretrained)if isinstance(checkpoint, OrderedDict):state_dict = checkpointelif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:state_dict = checkpoint['state_dict']if list(state_dict.keys())[0].startswith('module.'):state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}unexpected_keys = []  # 保存checkpoint不在module中的keymodel_state = self.state_dict()  # 模型变量for name, param in state_dict.items():  # 循环遍历pretrained的权重if name not in model_state:unexpected_keys.append(name)continueif isinstance(param, torch.nn.Parameter):# backwards compatibility for serialized parametersparam = param.datatry:model_state[name].copy_(param)  # 试图赋值给模型except Exception:raise RuntimeError('While copying the parameter named {}, ''whose dimensions in the model are {} not equal ''whose dimensions in the checkpoint are {}.'.format(name, model_state[name].size(), param.size()))missing_keys = set(model_state.keys()) - set(state_dict.keys())print('missing_keys:', missing_keys)def _make_layer(self, block, planes, num_blocks, stride=1):downsample = Noneif stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample))self.inplanes = planes * block.expansionfor i in range(1, num_blocks):layers.append(block(self.inplanes, planes))return nn.Sequential(*layers)def forward(self, x):outs = []x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)outs.append(x)featuremap_visual(x)x = self.layer2(x)outs.append(x)featuremap_visual(x)x = self.layer3(x)outs.append(x)featuremap_visual(x)x = self.layer4(x)outs.append(x)# x = self.avgpool(x)# x = x.view(x.size(0), -1)# x = self.fc(x)return tuple(outs)def featuremap_visual(feature,out_dir=None,  # 特征图保存路径文件save_feature=True,  # 是否以图片形式保存特征图show_feature=True,  # 是否使用plt显示特征图feature_title=None,  # 特征图名字,默认以shape作为titlenum_ch=-1,  # 显示特征图前几个通道,-1 or None 都显示nrow=8,  # 每行显示多少个特征图通道padding=10,  # 特征图之间间隔多少像素值pad_value=1  # 特征图之间的间隔像素):import matplotlib.pylab as pltimport torchvisionimport os# feature = feature.detach().cpu()b, c, h, w = feature.shapefeature = feature[0]feature = feature.unsqueeze(1)if c > num_ch > 0:feature = feature[:num_ch]img = torchvision.utils.make_grid(feature, nrow=nrow, padding=padding, pad_value=pad_value)img = img.detach().cpu()img = img.numpy()images = img.transpose((1, 2, 0))# title = str(images.shape) if feature_title is None else str(feature_title)title = str('hwc-') + str(h) + '-' + str(w) + '-' + str(c) if feature_title is None else str(feature_title)plt.title(title)plt.imshow(images)if save_feature:# root=r'C:\Users\Administrator\Desktop\CODE_TJ\123'# plt.savefig(os.path.join(root,'1.jpg'))out_root = title + '.jpg' if out_dir == '' or out_dir is None else os.path.join(out_dir, title + '.jpg')plt.savefig(out_root)if show_feature:        plt.show()import cv2
import numpy as npdef imnormalize(img,mean=[123.675, 116.28, 103.53],std=[58.395, 57.12, 57.375],to_rgb=True):if to_rgb:img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = img.astype(np.float32)return (img - mean) / stdif __name__ == '__main__':import matplotlib.pylab as pltimg = cv2.imread('1.jpg')  # 读取图片img = imnormalize(img)img = torch.from_numpy(img)img = torch.unsqueeze(img, 0)img = img.permute(0, 3, 1, 2)img = torch.tensor(img, dtype=torch.float32)img = img.to('cuda:0')model = Resnet(depth=50)model.init_weights(pretrained='./resnet50.pth')  # 可以使用,也可以注释model = model.cuda()out = model(img)

运行结果

参考:

PyTorch模型训练特征图可视化 - tangjunjun - 博客园 (cnblogs.com)

相关文章:

PyTorch实战:模型训练中的特征图可视化技巧

1.特征图可视化,这种方法是最简单,输入一张照片,然后把网络中间某层的输出的特征图按通道作为图片进行可视化展示即可。 2.特征图可视化代码如下: def featuremap_visual(feature, out_dirNone, # 特征图保存路径文件save_feat…...

有人@你!神工坊知识问答第二期中奖名单新鲜出炉

六月作为伟大的物理学家—麦克斯韦的诞辰月 神工坊特别推出 “ 辨 ‘麦克斯韦妖’,赢百元好礼” 夏日知识问答主题活动 活动一经推出 反响热烈 第二期中奖名单公布! 中奖的伙伴们速来兑奖! 听说还有同学没有参与? 活动最后…...

数据结构篇:旋转操作在AVL树中的实现过程

本节课在线学习视频(网盘地址,保存后即可免费观看): https://pan.quark.cn/s/06d5ed47e33b AVL树是平衡二叉搜索树的一种,它通过旋转操作来保持树的平衡。AVL树的特点是,任何节点的两个子树的高度最大差别…...

为什么Java默认使用UTF-16,Golang默认使用UTF-8呢?

Java 和 Go 语言在默认字符编码上做出了不同的选择,这是由它们的设计目标和使用场景决定的。下面是对 Java 默认使用 UTF-16 和 Go 默认使用 UTF-8 的原因进行的详细解释。 Java 默认使用 UTF-16 的原因 1. 历史背景和兼容性 Unicode 的发展: Java 诞生于 1995 年…...

JavaScript常见面试题(三)

文章目录 1.对原型、原型链的理解2.原型修改、重写3.原型链指向4.对闭包的理解5. 对作用域、作用域链的理解6.对执行上下文的理解7.对this对象的理解8. call() 和 apply() 的区别?9.异步编程的实现方式?10.setTimeout、Promise、Async/Await 的区别11.对…...

【Effective Modern C++】第1章 型别推导

【Effective Modern C】第1章 型别推导 文章目录 【Effective Modern C】第1章 型别推导条款1:理解模板型别推导基础概念模板型别推导的三种情况情景一 ParamType 是一个指针或者引用,但非通用引用情景二 ParamType是一个通过引用情景三 ParamType既不是…...

服装连锁实体店bC一体化运营方案

一、引言 随着互联网的快速发展和消费者购物习惯的变化,传统服装连锁实体店在面对新的市场环境下亟需转型升级。BC(Business to Consumer)一体化运营方案的实施将成为提升服装连锁实体店竞争力和顾客体验的关键举掖。商淘云详细介绍服装连锁…...

IDEA中SpringMVC的运行环境问题

文章目录 一、IEAD 清理缓存二、用阿里云和spring创建 SpringMVC 项目中 pom.xml 文件的区别 一、IEAD 清理缓存 springMVC 运行时存在一些之前运行过的缓存导致项目不能运行,可以试试清理缓存 二、用阿里云和spring创建 SpringMVC 项目中 pom.xml 文件的区别 以下…...

Python初体验

# Java基础知识学的差不多了,项目上又没什么事,学学py,方便以后对接 1、打包flask应用(好痛苦,在什么平台打包就只在那个平台可用想在linux用只能参考方法2了) pyinstaller --onefile app.py -n myapp 2…...

从零开始如何学习人工智能?

说说我自己的情况:我接触AI的时候,是在研一。那个时候AlphaGo战胜围棋世界冠军李世石是大新闻,人工智能第一次出现我面前,当时就想搞清楚背后的原理以及这些技术有什么作用。 就开始找资料,看视频。随着了解的深入&am…...

【仿真建模-anylogic】动态生成ConveyorCustomStation

Author:赵志乾 Date:2024-06-18 Declaration:All Right Reserved!!! 0. 背景 直接使用Anylogic组件开发的模型无法动态改变运输网布局;目前需求是要将运输网布局配置化;运输网配置化…...

如何使用idea连接Oracle数据库?

idea版本:2021.3.3 Oracle版本:10.2.0.1.0(在虚拟机Windows sever 2003 远程连接数据库) 数据库管理系统:PLSQL Developer 在idea里面找到database,在idea侧面 选择左上角加号,新建&#xff…...

谈谈kafaka的并行处理,顺带讲讲rabbitmq

简介 Kafka 是一个分布式流处理平台,它支持高效的并行处理。Kafka 的并行处理能力主要体现在以下几个方面: 分区(Partition)并行 Kafka 将数据存储在称为"分区"的逻辑单元中。每个分区可以独立地并行地进行读写操作。生产者可以根据分区策略,将数据写入到指定的分…...

P3056 [USACO12NOV] Clumsy Cows S

[USACO12NOV] Clumsy Cows S 题目描述 Bessie the cow is trying to type a balanced string of parentheses into her new laptop, but she is sufficiently clumsy (due to her large hooves) that she keeps mis-typing characters. Please help her by computing the min…...

智赢选品,OZON数据分析选品利器丨萌啦OZON数据

在电商行业的激烈竞争中,如何快速准确地把握市场动态、洞察消费者需求、实现精准选品,是每个电商卖家都面临的挑战。而在这个数据驱动的时代,一款强大的数据分析工具无疑是电商卖家们的得力助手。今天,我们就来聊聊这样一款选品利…...

Canal自定义客户端

一、背景 在Canal推送数据变更信息至MQ(消息队列)时,我们遇到了特定问题,尤其是当消息体的大小超过了MQ所允许的最大限制。这种限制导致数据推送过程受阻,需要相应的调整或处理。 二、解决方法 采用Canal自定义客户…...

20240621将需要自启动的部分放到RK3588平台的Buildroot系统的rcS文件中

20240621将需要自启动的部分放到RK3588平台的Buildroot系统的rcS文件中 2024/6/21 17:15 开发板:飞凌OK3588-C SDK:Rockchip原厂的Buildroot 缘起:在凌OK3588-C的LINUX R4系统启动的时候,需要拉高GPIO4_B5、GPIO3_B7和GPIO3_D0。…...

掌握数据魔方:Xinstall引领ASA全链路数据归因新纪元

一、引言 在数字化时代,数据是App推广和运营的核心驱动力。然而,如何准确获取、分析并应用这些数据,却成为了许多开发者和营销人员面临的痛点。Xinstall作为一款专业的App全渠道统计服务商,致力于提供精准、高效的数据解决方案&a…...

IIS代理配置-反向代理

前后端分离项目,前端在开发中使用proxy代理解决跨域问题,打包之后无效。 未配置前无法访问 部署环境为windows IIS,要在iis设置反向代理 安装代理模块 需要在iis中实现代理,需要安装Application Request Routing Cache和URL重…...

Flutter调用本地web

前言: 在目前Flutter 环境中,使用在线 webview 是一种很常见的行为 而在 app 环境中,离线使用则更有必要 1.环境准备 将依赖导入 2.引入前端代码 前端代码有两种情况 一种是使用打包工具 build 而来的前端代码 另一种情况是直接使用 HTML 文件 …...

AI大模型部署Ubuntu服务器攻略

一、下载Ollama 在线安装: 在linux中输入命令curl -fsSL https://ollama.com/install.sh | sh 由于在linux下载ollama需要经过外网,网络会不稳定,很容易造成连接超时的问题。 离线安装: 步骤一: 下载Ollama离线版本…...

vlan、vxlan、vpc学习

文章目录 前言VLAN (Virtual Local Area Network)定义工作原理优点应用场景限制 VXLAN (Virtual eXtensible Local Area Network)工作原理优点应用场景与VLAN的区别 VPC (Virtual Private Cloud)定义特点优势应用场景与VLAN/VXLAN的关联 总结 前言 VLAN(Virtual Lo…...

低代码开发:加速工业数智化转型发展

引言 在当今全球经济一体化和信息化的深度融合的大环境下,工业数智化转型已经成为推动制造业高质量发展的关键因素。这一转型不仅涉及生产过程的智能化、网络化,还涉及到企业管理、市场服务等全方位的数字化升级,其最终目标是为了实现更高效能…...

python“__main__“的解读

Tutorial Gross tutorial 有些模块包含了仅供脚本使用的代码,比如解析命令行参数或从标准输入获取数据。 如果这样的模块被从不同的模块中导入,例如为了单元测试,脚本代码也会无意中执行。 这就是 if name ‘main’ 代码块的用武之地。除非…...

Linux Debian12使用podman安装pikachu靶场环境

一、pikachu简介 Pikachu是一个带有漏洞的Web应用系统,在这里包含了常见的web安全漏洞。 二、安装podman环境 Linux Debian系统如果没有安装podman容器环境,可以参考这篇文章先安装podman环境, Linux Debian11使用国内源安装Podman环境 三…...

跑通并使用Yolo v5的源代码并进行训练—目标检测

跑通并使用Yolo v5的源代码并进行训练 摘要:yolo作为目标检测计算机视觉领域的核心网络模型,虽然到24年已经出到了v10的版本,但也很有必要对之前的核心版本v5版本进行进一步的学习。在学习yolo v5的时候因为缺少论文所以要从源代码入手来体验…...

需求虽小但是问题很多,浅谈JavaScript导出excel文件

最近我在进行一些前端小开发,遇到了一个小需求:我想要将数据导出到 Excel 文件,并希望能够封装成一个函数来实现。这个函数需要接收一个二维数组作为参数,数组的第一行是表头。在导出的过程中,要能够确保避免出现中文乱…...

phar反序列化及绕过

目录 一、什么是phar phar://伪协议格式: 二、phar结构 1.stub phar:文件标识。 格式为 xxx; *2、manifest:压缩文件属性等信息,以序列化存 3、contents:压缩文件的内容。 4、signature:签名&#…...

汽车IVI中控开发入门及进阶(三十):视频图像滚动问题分析(imx6+TVP5150+Camera)

前言: DA主控SOC采用imx6,TVP5150作为camera摄像头视频的解码decode芯片,imx6采用linux系统。 关于imx6,请参阅:汽车IVI中控开发入门及进阶(二十九):i.MX6-CSDN博客 Contributor III:...

给PDF添加书签的通解-姜萍同款《偏微分方程》改造手记

背景 网上找了一本姜萍同款的《偏微分方程》,埃文斯,英文版,可惜没有书签,洋洋七百多页,没有书签,怎么读?用福昕编辑器自然能手工一个个加上,可是劳神费力,非程序员所为…...

整容医院网站建设目的/哪家网站优化公司好

今天的收获蛮都的啊,修改了几篇之前写的博文,给他们都加了一下的图片,现在我们再来看看信息队列系统函数的调用吧! 首先我们纵观一下信息队列函数有哪些吧! msgget(),msgsnd&#…...

有没有专做零食批发的网站/南宁百度网站推广

声明使用JDK8测试; 参考官网配置,网址如下: http://www.mybatis.org/mybatis-3/zh/configuration.html aggressiveLazyLoading 在mybatis版本小于3.4.1时候,默认是true开启状态,懒加载要有效果需要手动关闭&#xff1…...

郑州做网站那家好/百度sem运营

关于jQuery的链式调用真正有意义的链式调用也就是方法链(method chaining)。方法链这个词是有的,而且使用的很广泛。其实很多人口中的“链式调用”实际上就是指方法链。但是“链式调用”这个词语还可以描述函数调用链,所以让它自身的存在价值变得难以理解…...

网站 移动化/体验营销是什么

由于wildfly是jboss 8.x以上的版本,并且默认jdk配置要求是1.8以上,故在启动add-user.bat文件时,如果jdk版本过低,cmd会报异常,其中原因就有可能是jdk版本低造成。此时需要检查环境变量里JAVA_HOME的配置是否为jdk1.8的…...

深圳制作网站建设推广/沈阳seo排名优化推广

WIN2003 IIS最小权限分配.bat代码如下:echo off echo "虚拟主机C盘权限设定" echo "Author:an85.com" echo "删除C盘的everyone的权限" cd/ cacls "%SystemDrive%" /r "everyone" /e cacls "%SystemRoot%" /r &qu…...

百度糯米网站怎么做/电商怎么做新手入门

文章目录简介推荐参数1. 前置条件1.1 点到字符串的转换压缩未压缩混合形式1.2 密钥派生函数6. 加解密加密流程解密流程实现参考资料简介 国密SM2算法并不仅仅是提供了新的曲线参数,而是在算法上对ECC进行了修改。 SM2的曲线使用了Weierstrass模型: y2x…...