当前位置: 首页 > news >正文

基于知识蒸馏的两阶段去雨去雪去雾模型学习记录(三)之知识测试阶段与评估模块

去雨去雾去雪算法分为两个阶段,分别是知识收集阶段与知识测试阶段,前面我们已经学习了知识收集阶段,了解到知识阶段的特征迁移模块(CKT)与软损失(SCRLoss),那么在知识收集阶段的主要重点便是HCRLoss(硬损失),事实上,知识测试阶段要比知识收集阶段简单,因为这个模块只需要训练学生网络即可。

模型创新点

在进行知识测试阶段的代码学习之前,我们来回顾一下去雨去雪去雾网络的创新点:
首先是提出两阶段的知识蒸馏网络,即构建三个教师网络与一个学生网络,设置总训练次数为250,其中前125个epoch教师网络与学生网络一同训练,这里的训练是指将图像输入教师网络,随后将教师网络的输出结果与中间特征图保留,将其作为真值指导学生网络进行训练。
其次便是提出知识迁移模块(CKT)该模块的作用是将教师网络的特征迁移到学生网络。
随后便是软损失与硬损失计算了,这个其实是知识蒸馏中的概念。
总体来看去雨去雾去雪网络的设计虽然较为新颖,但事实上就是知识蒸馏网络的架构,本着这一点,程序理解起来也就容易多了。

在这里插入图片描述

接下来开始代码的学习:

小插曲(算力不足)

首先需要指出,前面将batch-size设置为4,但却会报错:

RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED

开始时博主以为是cuDNN与CUDA版本不匹配导致的,但后来一想不对呀,先前已经运行过呀,那么问题很可能便是batch出问题了,果然将batch改为3后就正常了,这是由于算力不足导致的,注意算力不足和显存不足还是有区别的。
将batch-size改为3后重新运行,开始知识测试阶段的探索。

知识测试阶段

事实上,知识测试阶段的实现与知识收集阶段几乎相同,并且要比知识收集阶段简单,其只是训练学生网络,并计算一个硬损失而已。
由于知识测试阶段与知识收集阶段几乎相同,因此有许多地方是重复的,这里博主便会简要介绍。
首先相同的是使用train_loader进行训练集的加载,并使用tqdm进行封装。
随后便是遍历过程,这个过程就要简单很多了,没有使用到教师网络,直接将图像输入学生网络进行预测即可,这里的学生网络与教师网络的构造是完全相同的,将结果分别计算L1损失与HCR_loss即可。不过需要注意的是由于该阶段不需要与教师网络进行特征迁移,因此就不需要返回中间特征图了,即设置return_feat=False

for target_images, input_images in pBar:if target_images is None: continuetarget_images = target_images.cuda()input_images = torch.cat(input_images).cuda()preds = model(input_images, return_feat=False)G_loss = criterion_l1(preds, target_images)HCR_loss = 0.2 * criterion_hcr(preds, target_images, input_images)total_loss = G_loss + HCR_loss

至于其他的基本就相同了,需要注意的是这里的batch设置为3。接下来记录一下数据的变化情况:

input_images:输入图像,torch.Size([3, 3, 224, 224])第一个3是指图像数量,第二个3是指通道维度
target_images:目标图像(真值),torch.Size([3, 3, 224, 224])第一个3是指图像数量,第二个3是指通道维度
preds:预测图像(去噪后的图像),torch.Size([3, 3, 224, 224])第一个3是指图像数量,第二个3是指通道维度

在这里插入图片描述

随后计算L1损失与HCRLoss,由于在学生网络中使用的事实上是混合数据集,即不区分去噪类型,因此输入图像等都是直接使用tesnor格式,而非list格式。

G_loss:tensor(0.5621, device='cuda:0', grad_fn=<L1LossBackward>)

HCRLoss

SCRLoss相同,HCRLoss也是先将图像进行特征转换后再计算损失的

HCRLoss((vgg): Vgg19((slice1): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True))(slice2): Sequential((2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True))(slice3): Sequential((7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True))(slice4): Sequential((12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(17): ReLU(inplace=True)(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True))(slice5): Sequential((21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(24): ReLU(inplace=True)(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(26): ReLU(inplace=True)(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)))(l1): L1Loss()
)
HCRLoss:tensor(0.3274, device='cuda:0', grad_fn=<MulBackward0>)

评估模块

至此,知识测试阶段便完成了,随后便是模型评估了。这里默认设置评估时的batch-size为1,即每次输入一张图像。
所谓的评估指的是对学生网络的评估,该模块其实与知识测试阶段类似,不同之处在于这里是需要计算SSIMPSNR的。至于其他则是完全相同,核心代码如下:

for target, image in pBar:if torch.cuda.is_available():image = image.cuda()target = target.cuda()pred = model(image)   		psnr_list.append(torchPSNR(pred, target).item())ssim_list.append(pytorch_ssim.ssim(pred, target).item())

由于batch-size设置为1,因此targettorch.Size([1, 3, 480, 640])image也为torch.Size([1, 3, 480, 640]),这里需要注意的是,在训练阶段(包含知识收集与知识测试阶段),数据集中的图像都要转换为224x224的大小,而在评估阶段则不需要进行转换了,即使用的是原图像的大小。
直接将输入图输入模型,获的去噪后的图像pred大小为torch.Size([1, 3, 480, 640])

pred = model(image)

在这里插入图片描述

随后将预测图像与真值图像进行计算PSNR与SSIM

psnr_list.append(torchPSNR(pred, target).item())
ssim_list.append(pytorch_ssim.ssim(pred, target).item())

PSNR计算

@torch.no_grad()
def torchPSNR(prd_img, tar_img):if not isinstance(prd_img, torch.Tensor):prd_img = torch.from_numpy(prd_img)tar_img = torch.from_numpy(tar_img)imdff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1)rmse = (imdff**2).mean().sqrt()ps = 20 * torch.log10(1/rmse)return ps

SSIM计算

class SSIM(torch.nn.Module):def __init__(self, window_size = 11, size_average = True):super(SSIM, self).__init__()self.window_size = window_sizeself.size_average = size_averageself.channel = 1self.window = create_window(window_size, self.channel)def forward(self, img1, img2):(_, channel, _, _) = img1.size()if channel == self.channel and self.window.data.type() == img1.data.type():window = self.windowelse:window = create_window(self.window_size, channel)            if img1.is_cuda:window = window.cuda(img1.get_device())window = window.type_as(img1)        self.window = windowself.channel = channelreturn _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size = 11, size_average = True):(_, channel, _, _) = img1.size()window = create_window(window_size, channel)  if img1.is_cuda:window = window.cuda(img1.get_device())window = window.type_as(img1)return _ssim(img1, img2, window, window_size, channel, size_average)

将每个循环得到的psnrssim加入列表

在这里插入图片描述
最后的PSNRSSIM是对list中的所有值求平均:

print("PSNR: {:.3f}".format(np.mean(psnr_list)))
print("SSIM: {:.3f}".format(np.mean(ssim_list)))

至此,知识测试阶段与评估模块就讲解完成了,接下来博主将对该模型进行改进。

相关文章:

基于知识蒸馏的两阶段去雨去雪去雾模型学习记录(三)之知识测试阶段与评估模块

去雨去雾去雪算法分为两个阶段&#xff0c;分别是知识收集阶段与知识测试阶段&#xff0c;前面我们已经学习了知识收集阶段&#xff0c;了解到知识阶段的特征迁移模块&#xff08;CKT)与软损失&#xff08;SCRLoss&#xff09;,那么在知识收集阶段的主要重点便是HCRLoss(硬损失…...

代码随想录二刷day46

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、力扣139. 单词拆分二、力扣动态规划&#xff1a;关于多重背包&#xff0c;你该了解这些&#xff01; 前言 提示&#xff1a;以下是本篇文章正文内容&#x…...

计算机竞赛 行人重识别(person reid) - 机器视觉 深度学习 opencv python

文章目录 0 前言1 技术背景2 技术介绍3 重识别技术实现3.1 数据集3.2 Person REID3.2.1 算法原理3.2.2 算法流程图 4 实现效果5 部分代码6 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 深度学习行人重识别(person reid)系统 该项目…...

在线图片转BASE64、在线BASE64转图片

图片转BASE64、BASE64转图片...

什么是RPA?一文了解RPA发展与进程!

RPA&#xff08;Robotic Process Automation&#xff0c;机器人流程自动化&#xff09;是一种通过软件机器人模拟人类在计算机上执行重复性任务的技术。RPA的核心理念是将规则、过程和数据“机器人化”&#xff0c;从而实现对业务流程的自动化。RPA技术可以显著提高企业的工作效…...

【云备份项目】【Linux】:环境搭建(g++、json库、bundle库、httplib库)

文章目录 1. g 升级到 7.3 版本2. 安装 jsoncpp 库3. 下载 bundle 数据压缩库4. 下载 httplib 库从 Win 传输文件到 Linux解压缩 1. g 升级到 7.3 版本 &#x1f517;链接跳转 2. 安装 jsoncpp 库 &#x1f517;链接跳转 3. 下载 bundle 数据压缩库 安装 git 工具 sudo yum…...

工信部教考中心:什么是《研发效能(DevOps)工程师》认证,拿到证书之后有什么作用!(下篇)丨IDCF

拿到证书有什么用&#xff1f; 提高职业竞争力&#xff1a;通过学习认证培训课程可以提升专业技能&#xff0c;了解项目或产品研发全生命周期的核心原则&#xff0c;掌握端到端的研发效能提升方法与实践&#xff0c;包括组织与协作、产品设计与运营、开发与交付、测试与安全、…...

Linux进程相关管理(ps、top、kill)

目录 一、概念 二、查看进程 1、ps命令查看进程 1&#xff09;ps显示某个时间点的程序运行情况 2&#xff09;查看指定的进程信息 2、top命令查看进程 1&#xff09;信息统计区&#xff1a; 2&#xff09;进程信息区 3&#xff09;交互式命令 三、信号控制进程 四、…...

微服务技术栈-Ribbon负载均衡和Nacos注册中心

文章目录 前言一、Ribbon负载均衡1.LoadBalancerInterceptor&#xff08;负载均衡拦截器&#xff09;2.负载均衡策略IRule 二、Nacos注册中心1.Nacos简介2.搭建Nacos注册中心3.服务分级存储模型4.环境隔离5.Nacos与Eureka的区别 总结 前言 在上面那个文章中介绍了微服务架构的…...

知识图谱和大语言模型的共存之道

源自&#xff1a;开放知识图谱 “人工智能技术与咨询” 发布 导 读 01 知识图谱和大语言模型的历史 图1 图2 图3 图4 图5 02 知识图谱和大语言模型作为知识库的优缺点 图6 图7 表1 表2 图8 图9 03 知识图谱和大语言模型双知识平台融合 图10 图11 04 总结与展望 声明:公众号转…...

enum, sizeof, typedef

枚举类型enum enum 是 C 语言中的一种自定义类型enum 值是可以根据需要自定义的整型值第一个定义的 enum 值默认为 0默认情况下的 enum 值在前一个定义值得基础上加 1enum 类型的变量只能取定义时得离散值 void code() {enum Color{GREEN, // 0RED 2, // 2BLUE, …...

(二)激光线扫描-相机标定

1. 何为相机标定? 当相机拍摄照片时,我们看到的图像通常与我们实际看到的不完全相同。这是由相机镜头引起的,而且发生的频率比我们想象的要高。 这种图像的改变就是我们所说的畸变。一般来说,畸变是指直线在图像中出现弯曲或弯曲。 这种畸变我们可以通过相机标定来进行解…...

pytorch 数据载入

在PyTorch中&#xff0c;数据载入是训练深度学习模型的重要一环。 本文将介绍三种常用的数据载入方式&#xff1a;Dataset、DataLoader、以及自定义的数据加载器。 使用 Dataset 载入数据 方法&#xff1a; from torch.utils.data import Datasetclass CustomDataset(Dataset…...

angular 在vscode 下的hello world

Angulai 是google 公司开发的前端开发框架。Angular 使用 typescript 作为编程语言。typescript 是Javascript 的一个超集&#xff0c;提升了某些功能。本文介绍运行我的第一个angular 程序。 前面部分参考&#xff1a; Angular TypeScript Tutorial in Visual Studio Code 一…...

Django、Nginx、uWSGI详解及配置示例

一、Django、Nginx、uWSGI的概念、联系与区别 Django、Nginx 和 uWSGI 都是用于构建和运行 Web 应用程序的软件&#xff0c;这三个软件的概念如下&#xff1a; Django&#xff1a;Django 是一个基于 Python 的开源 Web 框架&#xff0c;它提供了一套完整的工具和组件&#xf…...

王道考研计算机组成原理——计算机硬件的基础知识

计算机组成原理的基本概念 计算机硬件的针脚都是用来传递信息&#xff0c;传递数据用的&#xff1a; 服务程序包含一些调试程序&#xff1a; 计算机硬件的基本组成 控制器通过电信号来协调其他部件的工作&#xff0c;同时负责解析存储器里存放的程序指令&#xff0c;然后指挥…...

[晕事]今天做了件晕事21;设置代理访问网站的时候需注意的问题

今天在家上班&#xff0c;设置好VPN&#xff0c;通过代理来访问公司内部的一个系统浏览器的反应如下&#xff1a; Hmmm… can’t reach this page ***.com refused to connect. 这个返回的错误&#xff0c;非常的具有迷惑性&#xff0c;提示的意思&#xff1a;拒绝链接&#xf…...

Go通过reflect.Value修改值

到目前为止&#xff0c;反射还只是程序中变量的另一种读取方式。然而&#xff0c;在本节中我们将重点讨论如何通过反射机制来修改变量。 回想一下&#xff0c;Go语言中类似x、x.f[1]和*p形式的表达式都可以表示变量&#xff0c;但是其它如x 1和f(2)则不是变量。一个变量就是一…...

【MySql】4- 实践篇(二)

文章目录 1. SQL 语句为什么变“慢”了1.1 什么情况会引发数据库的 flush 过程呢&#xff1f;1.2 四种情况性能分析1.3 InnoDB 刷脏页的控制策略 2. 数据库表的空间回收2.1 innodb_file_per_table参数2.2 数据删除流程2.3 重建表2.4 Online 和 inplace 3. count(*) 语句怎样实现…...

获取多个接口的数据并进行处理,使用Promise.all来等待所有接口请求完成

Promise.all (等待机制) 方法 它调用了多个函数&#xff0c;这些函数返回了Promise对象&#xff0c;每个Promise对象代表了一个异步操作。 然后&#xff0c;使用Promise.all将这多个Promise对象包装成一个新的Promise对象&#xff0c;它会等待所有的Promise都完成&#xff08;或…...

利用C++开发一个迷你的英文单词录入和测试小程序-升级版本

我们现在有了一个本地sqlite3的迷你英文单词小测试工具&#xff0c;需求就跟工作当中一样是不断变更的。这里虚构两个场景&#xff0c;并且一步一步的完成最终升级后的小demo。 场景&#xff1a;数据不依赖本地sqlite3&#xff0c;需要支持远程访问&#xff0c;用目前的restfu…...

用c动态数组(实现权重矩阵可视化)实现手撸神经网络230902

变量即内存、指针使用的架构原理: 1、用结构struct记录 网络架构,如 float*** ws 为权重矩阵的指针(指针地址); 2、用 = (float*)malloc (Num * sizeof(float)) 给 具体变量分配内存; 3、用 = (float**)malloc( Num* sizeof(float*) ) 给 指向 具体变量(一维数组)的…...

Android.mk和Android.bp

公司承接Android、iOS等APP开发、前后端网站开发、小程序开发、安全服务等项目&#xff01; 公司官网:www.bincodesec.com 项目案例 一、编译不同类型的模块 1.编译成Java库 Android.mk include $(BUILD_JAVA_LIBRARY)Android.bp java_library {} 2.编译成Java静态库 And…...

CSS 常用样式-文本属性

一、水平对齐 text-align CSS中的text-align属性用于水平对齐文本。它可以应用于块级元素和表格单元格。 常见的属性值包括&#xff1a; left&#xff1a;左对齐&#xff0c;文本在容器的左侧。right&#xff1a;右对齐&#xff0c;文本在容器的右侧。center&#xff1a;居中…...

BootstrapBlazor企业级组件库:前端开发的革新之路

作为一名Web开发人员&#xff0c;开发前端我们一般都是使用JavaScript&#xff0c;而Blazor就是微软推出的基于.Net平台交互式客户Web UI 框架&#xff0c;可以使用C#替代JavaScript&#xff0c;减少我们的技术栈、降低学习前端的成本。 而采用Blazor开发&#xff0c;少不了需…...

力扣 -- 1745. 分割回文串 IV

解题步骤&#xff1a; 参考代码&#xff1a; class Solution { public:bool checkPartitioning(string s) {int ns.size();vector<vector<bool>> dp(n,vector<bool>(n));for(int in-1;i>0;i--){for(int ji;j<n;j){if(s[i]s[j]){dp[i][j]i1<j?dp[i…...

C# 给某个方法设定执行超时时间

C# 给某个方法设定执行超时时间在某些情况下(例如通过网络访问数据)&#xff0c;常常不希望程序卡住而占用太多时间以至于造成界面假死。 在这时、我们可以通过Thread、Thread Invoke&#xff08;UI&#xff09;或者是 delegate.BeginInvoke 来避免界面假死&#xff0c; 但是…...

安装NodeJS并使用yarn下载前端依赖

文章目录 1、安装NodeJS1.1 下载NodeJS安装包1.2 解压并配置NodeJS1.3 验证是否安装成功2、使用yarn下载前端依赖2.1 安装yarn2.2 使用yarn下载前端依赖参考目标:在Windows下安装新版NodeJS,并使用yarn下载前端依赖,实现运行前端项目。 1、安装NodeJS 1.1 下载NodeJS安装包…...

(Java高级教程)第三章Java网络编程-第八节:博客系统搭建(前后端分离)

文章目录 一&#xff1a;前端页面回顾二&#xff1a;博客功能展示三&#xff1a;数据库表设计&#xff08;1&#xff09;表设计&#xff08;2&#xff09;封装DataSource 四&#xff1a;实体类和数据访问对象&#xff08;1&#xff09;实体类&#xff08;2&#xff09;数据访问…...

901. 股票价格跨度

设计一个算法收集某些股票的每日报价&#xff0c;并返回该股票当日价格的 跨度 。 当日股票价格的 跨度 被定义为股票价格小于或等于今天价格的最大连续日数&#xff08;从今天开始往回数&#xff0c;包括今天&#xff09;。 例如&#xff0c;如果未来 7 天股票的价格是 [100,…...

视觉传达设计考研/seo百度关键词优化

Linux网络编程--4. 完整的读写函数来源:http://linuxc.51.net 作者:hoyt(2001-05-08 11:20:52)一旦我们建立了连接,我们的下一步就是进行通信了.在Linux下面把我们前面建立的通道 看成是文件描述符,这样服务器端和客户端进行通信时候,只要往文件描述符里面读写东西了. 就象我们…...

合肥婚恋网站建设/网络营销平台有哪些?

类似问题答案北京交通大学计算机类专业2016年在黑龙江理科高考录取最低分数线学校 地 区 专业 年份 批次 类型 分数 北京交通大学 黑龙江 计算机类 2016 一批 理科 625 学校 地 区 专业 年份 批次 类型 分数 北京交通大学 黑龙江 计算机类 2016 一批 理科 625 北京交通大学 黑龙…...

hexo发布wordpress/国内搜索引擎优化的公司

mysql 字段使用as在mysql中&#xff0c;select查询可以使用AS关键字为查询的字段起一个别名&#xff0c;该别名用作表达式的列名&#xff0c;并且别名可以在GROUP BY&#xff0c;ORDER BY或HAVING等语句中使用。例如&#xff1a;SELECT CONCAT(last_name,, ,first_name) AS ful…...

自己怎么开网站备案/合理使用说明

一、背景现在使用基于Git 作为开发项目的管理工具已经非常普遍&#xff0c;很多与Git相关的平台工具的基本配置和使用方法都类似&#xff0c;现主要总结一下Git的基本配置&#xff0c;教你如何从已经存在项目的Git上clone代码到本地。使用操作系统&#xff1a;Mac OS二、添加和…...

asp动态网站开发软件/站长基地

} else { // 夜间模式 tv.setTextColor(R.color.skinPrimaryTextColor_Dark); } } 这种实现并非一无是处&#xff0c;从实现的难度而言&#xff0c;至少能够保护开发者为数不多的发囊。 当然&#xff0c;这种方案有「优化空间」&#xff0c;比如提供封装的工具方法 看似摆…...

wordpress数据库连接文件/广州seo快速排名

1.k8s高可用架构解析 2.基本环境配置 Kubeadm安装方式自1.14版本以后,安装方法几乎没有任何变化,此文档可以尝试安装最新的k8s集群,centos采用的是7.x版本 K8S官网:https://kubernetes.io/docs/setup/ 最新版高可用安装:https://kubernetes.io/docs/setup/production-e…...