当前位置: 首页 > news >正文

Pytorch Advanced(二) Variational Auto-Encoder

自编码说白了就是一个特征提取器,也可以看作是一个降维器。下面找了一张很丑的图来说明自编码的过程。

自编码分为压缩和解码两个过程。从图中可以看出来,压缩过程就是将一组数据特征进行提取, 得到更深层次的特征。解码的过程就是利用之前的深层次特征再还原成为原来的数据特征。那么如何保证从压缩到解码两部分,原数据和解码数据保持一致呢?这就是要训练的过程。

如何理解降维?如果压缩的过程是卷积,维度可以根据核的个数变化,特征维度因此而改变。


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')sample_dir = 'samples'
if not os.path.exists(sample_dir):os.makedirs(sample_dir)
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3dataset = torchvision.datasets.MNIST(root='../../data',train=True,transform=transforms.ToTensor(),download=True)# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size, shuffle=True)

模型搭建:这里搭建的是一个变分自编码,Variational Autoencoder

那么变分自编码是为了解决什么问题呢? ——- 其主要思想还是希望学习隐层变量,并将其用来表示原始数据,但是它加另一个条件, 即隐层变量能学习原始数据的分布, 并反过来生产一些和原始数据相似的数据(这有啥用?—-可用于图片修复,让图片按训练集的数据分布变化)。

变分自编码 (Variational Autoencoder) 为了让隐层抓住输入数据特性, 而不是简单的输出数据=输入数据,他在隐层中加入随机噪声(单位高斯噪声)(这个过程也叫reparametrize),以确保隐层能较好抽象输入数据特点。

代码中怎么做的呢?

1、编码过程中我们保存了第二层线性层的输出。其中第二层包含有fc2与fc3两部分,他们是并联的。

2、给隐藏层加入随机噪声,作为解码的输入

class VAE(nn.Module):def __init__(self, image_size=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.fc1 = nn.Linear(image_size, h_dim)self.fc2 = nn.Linear(h_dim, z_dim)self.fc3 = nn.Linear(h_dim, z_dim)self.fc4 = nn.Linear(z_dim, h_dim)self.fc5 = nn.Linear(h_dim, image_size)def encode(self, x):h = F.relu(self.fc1(x))return self.fc2(h), self.fc3(h)def reparameterize(self, mu, log_var):std = torch.exp(log_var/2)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = F.relu(self.fc4(z))return F.sigmoid(self.fc5(h))def forward(self, x):mu, log_var = self.encode(x)z = self.reparameterize(mu, log_var)x_reconst = self.decode(z)return x_reconst, mu, log_var

训练:由于训练中加入了噪声,所以损失值的结构也因此改变。一部分来源于解码内容核原内容的相似度,另一部分是kl_div,具体是什么意义需查看论文。

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# Start training
for epoch in range(num_epochs):for i, (x, _) in enumerate(data_loader):# Forward passx = x.to(device).view(-1, image_size)x_reconst, mu, log_var = model(x)# Compute reconstruction loss and kl divergence# For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())# Backprop and optimizeloss = reconst_loss + kl_divoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 10 == 0:print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))with torch.no_grad():# Save the sampled imagesz = torch.randn(batch_size, z_dim).to(device)out = model.decode(z).view(-1, 1, 28, 28)save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))# Save the reconstructed imagesout, _, _ = model(x)x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

模型训练完成了之后该如何使用这个模型呢?

model.decode()是一个解码的过程,我们给他一个随机的中间特征z就可以输出一个数字图片了。

z = torch.randn(1,z_dim).to(device)
out = model.decode(z)
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

有了随机的一张图片之后,我们把他完整的放入模型中,生成了和输入相似的一张图片,也没看出来是修复了图像......

out,_,_ = model(out) 
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

相关文章:

Pytorch Advanced(二) Variational Auto-Encoder

自编码说白了就是一个特征提取器,也可以看作是一个降维器。下面找了一张很丑的图来说明自编码的过程。 自编码分为压缩和解码两个过程。从图中可以看出来,压缩过程就是将一组数据特征进行提取, 得到更深层次的特征。解码的过程就是利用之前的…...

Flask 使用 JWT(三)flask-jwt-extended

如果想要在 flask 中使用 JWT ,推荐使用 flask-jwt-extended 插件。 使用 pip 安装这个扩展插件的最简单方法是: pip install flask-jwt-extended基本使用 在接下来的案例中,我们看一下基本使用。我们可以使用 create_access_token() 函数用来生成实际的 JWT token。@jwt_r…...

堆与栈的区别

OVERVIEW 栈与堆的区别一、程序内存分区中的堆与栈1.栈2.堆3.堆&栈 二、数据结构中的堆与栈1.栈2.堆 三、堆的深入1.堆插入2.堆删除:3.堆建立:4.堆排序:5.堆实现优先队列:6.堆与栈的相关练习 栈与堆的区别 自整理,…...

OpenWrt kernel install分析(2)

一. 前言 接下来分析make -C image compile install TARGET_BUILD。 二. Makefile分析 1. 命令首先运行target/linux/mediatek/image/Makefile,该文件内容如下: target/linux/mediatek/image/Makefile: include $(TOPDIR)/rules.mk include $(INCLUDE_DIR)/image.…...

【计算机网络】传输层协议——TCP(下)

文章目录 1. 三次握手三次握手的本质是建立链接,什么是链接?整体过程三次握手过程中报文丢失问题为什么2次握手不可以?为什么要三次握手? 2. 四次挥手整体过程为什么要等待2MSL 3. 流量控制4. 滑动窗口共识滑动窗口的一般情况理解…...

Vue前端页面打印

前端依赖10-插件"print-js": “^1.6.0” 一:简介 print-js 是一个 Vue.js 插件,用于在 Vue.js 项目中实现打印功能。它依赖于 print-js 库,所以需要安装这个库。 能实现以下功能: PDF打印(默认&#xff…...

Visual Studio将C#项目编译成EXE可执行程序

经常看文章时会收获不少实用工具,有的在github上是编译好的,有的则是未编译的项目文件。所以经常会使用Visual Studio编译项目文件成exe可执行程序,以下为编译的流程。 第一步,从github上下载项目文件,举个例子&#…...

git把某一次commit修改过的文件打包导出(git)

1、使用命令把修改的文件打包导出:打包某次commit: git diff-tree -r --no-commit-id --name-only f4710c4a32975904b00609f3145c709f31392140 | xargs tar -rf xxx_1.1.tar 2、使用命令把某次节点后的文件导出: window 下: git diff f4710c4a32975904b00609f3145c709f31392…...

Vue3 Ajax(axios)异步

文章目录 Vue3 Ajax(axios)异步1. 基础1.1 安装Ajax1.2 使用方法1.3 浏览器支持情况 2. GET方法2.1 参数传递2.2 实例 3. POST方法4. 执行多个并发请求5. axios API5.1 传递配置创建请求5.2 请求方法的别名5.3 并发5.4 创建实例5.5 实例方法5.6 请求配置项5.7 响应结构5.8 配置…...

idea2023全量方法debug

为什么要全量debug 刚上手项目或者研读开源项目源码的时候,我们对项目的结构,尤其是功能链路非常陌生,想要debug根本不知道断点打在哪,光靠文件名类名或者方法名去猜也不是个事。这时候只要配置一下全量debug模式,就能…...

Docker镜像解析获取Dockerfile文件

01、概述 当涉及到容器镜像的安全时,特别是在出现镜像投毒引发的安全事件时,追溯镜像的来源和解析Dockerfile文件是应急事件处理的关键步骤。在这篇博客中,我们将探讨如何从镜像解析获取Dockerfile文件,这对容器安全至关重要。 02…...

使用maven命令打jar包

参考:https://blog.csdn.net/qq_27525611/article/details/123487255 https://blog.csdn.net/qq_35860138/article/details/82701919 小伙伴给我的项目自己尝试命令行打包遇到的坑,简单记录下 // 打包(1.8环境下打的,17会报错&…...

【多线程】死锁 详解

死锁 一. 死锁是什么二. 死锁的场景1. 一个线程一把锁2. 两个线程两把锁3. N 个线程 M 把锁 三. 死锁产生的四个必要条件四. 如何避免死锁 一. 死锁是什么 死锁是这样一种情形: 多个线程同时被阻塞,因为每个进程都在等其他线程释放某些资源,…...

成考[专升本政治]科目必背知识点

1. 马克思主义哲学研究的对象是:关于自然、社会、思维发展的一般规律。 2. 对待马克思主义的科学态度是:坚持和发展。 3. 物质的唯一特性是客观实在性。这里的客观实在是指:不以人的意志为转移。 4. 在实际工作中,要注意掌握…...

spring boot 使用AOP+自定义注解+反射实现操作日志记录修改前数据和修改后对比数据,并保存至日志表

一、添加aop starter依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-aop</artifactId> </dependency>二&#xff1a;自定义字段翻译注解。&#xff08;修改功能时&#xff0c;需要显示如…...

【深度学习】Pytorch 系列教程(二):PyTorch数据结构:1、Tensor(张量): GPU加速(GPU Acceleration)

目录 一、前言 二、实验环境 三、PyTorch数据结构 0、分类 1、张量&#xff08;Tensor&#xff09; 1. 维度&#xff08;Dimensions&#xff09; 2. 数据类型&#xff08;Data Types&#xff09; 3. GPU加速&#xff08;GPU Acceleration&#xff09; 一、前言 ChatGP…...

多线程|多进程|高并发网络编程

一.多进程并发服务器 多进程并发服务器是一种经典的服务器架构&#xff0c;它通过创建多个子进程来处理客户端连接&#xff0c;从而实现并发处理多个客户端请求的能力。 概念&#xff1a; 服务器启动时&#xff0c;创建主进程&#xff0c;并绑定监听端口。当有客户端连接请求…...

云计算——ACA学习 云计算分类

作者简介&#xff1a;一名云计算网络运维人员、每天分享网络与运维的技术与干货。 公众号&#xff1a;网络豆 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a; 网络豆的主页​​​​​ 目录 写在前面 前期回顾 本期介绍 一.云计算分类 1.公有云…...

3 分钟,带你了解低代码开发

一、低代码平台存在的意义 传统软件开发交付链中&#xff0c;需求经过3次传递&#xff0c;用户→业务→架构师→开发&#xff0c;每一层传递都可能使需求失真&#xff0c;导致最终交付的功能返工。 业务的变化促使软件开发过程不断更新、迭代和演进&#xff0c;而低代码开发即是…...

小白学Unity03-太空漫游游戏脚本,控制飞船移动旋转

首先搭建好太阳系以及飞机的场景 需要用到3个脚本 1.控制飞机移动旋转 2.控制摄像机LookAt朝向飞机和差值平滑跟踪飞机 3.控制各个星球自转以及围绕太阳旋转&#xff08;rotate()和RotateAround()&#xff09; 1.控制飞机移动旋转的脚本 using System.Collections; using…...

19c补丁后oracle属主变化,导致不能识别磁盘组

补丁后服务器重启&#xff0c;数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后&#xff0c;存在与用户组权限相关的问题。具体表现为&#xff0c;Oracle 实例的运行用户&#xff08;oracle&#xff09;和集…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;百货中心供应链管理系统被用户普遍使用&#xff0c;为方…...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版​分享

平时用 iPhone 的时候&#xff0c;难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵&#xff0c;或者买了二手 iPhone 却被原来的 iCloud 账号锁住&#xff0c;这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...

定时器任务——若依源码分析

分析util包下面的工具类schedule utils&#xff1a; ScheduleUtils 是若依中用于与 Quartz 框架交互的工具类&#xff0c;封装了定时任务的 创建、更新、暂停、删除等核心逻辑。 createScheduleJob createScheduleJob 用于将任务注册到 Quartz&#xff0c;先构建任务的 JobD…...

【SQL学习笔记1】增删改查+多表连接全解析(内附SQL免费在线练习工具)

可以使用Sqliteviz这个网站免费编写sql语句&#xff0c;它能够让用户直接在浏览器内练习SQL的语法&#xff0c;不需要安装任何软件。 链接如下&#xff1a; sqliteviz 注意&#xff1a; 在转写SQL语法时&#xff0c;关键字之间有一个特定的顺序&#xff0c;这个顺序会影响到…...

页面渲染流程与性能优化

页面渲染流程与性能优化详解&#xff08;完整版&#xff09; 一、现代浏览器渲染流程&#xff08;详细说明&#xff09; 1. 构建DOM树 浏览器接收到HTML文档后&#xff0c;会逐步解析并构建DOM&#xff08;Document Object Model&#xff09;树。具体过程如下&#xff1a; (…...

Nuxt.js 中的路由配置详解

Nuxt.js 通过其内置的路由系统简化了应用的路由配置&#xff0c;使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序

一、开发环境准备 ​​工具安装​​&#xff1a; 下载安装DevEco Studio 4.0&#xff08;支持HarmonyOS 5&#xff09;配置HarmonyOS SDK 5.0确保Node.js版本≥14 ​​项目初始化​​&#xff1a; ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

爬虫基础学习day2

# 爬虫设计领域 工商&#xff1a;企查查、天眼查短视频&#xff1a;抖音、快手、西瓜 ---> 飞瓜电商&#xff1a;京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空&#xff1a;抓取所有航空公司价格 ---> 去哪儿自媒体&#xff1a;采集自媒体数据进…...

DeepSeek 技术赋能无人农场协同作业:用 AI 重构农田管理 “神经网”

目录 一、引言二、DeepSeek 技术大揭秘2.1 核心架构解析2.2 关键技术剖析 三、智能农业无人农场协同作业现状3.1 发展现状概述3.2 协同作业模式介绍 四、DeepSeek 的 “农场奇妙游”4.1 数据处理与分析4.2 作物生长监测与预测4.3 病虫害防治4.4 农机协同作业调度 五、实际案例大…...