当前位置: 首页 > news >正文

用VAE生成图像

用VAE生成图像

  • 自编码器AE,auto-encoder
  • VAE
    • 讲讲为什么是log_var
    • 为什么要用重参数化技巧
  • 用VAE生成图像

变分自编码器是自编码器的改进版本,自编码器AE是一种无监督学习,但它无法产生新的内容,变分自编码器对其潜在空间进行拓展,使其满足正态分布,情况就大不一样了。

自编码器AE,auto-encoder

自编码器是通过对输入X进行编码后得到一个低维的向量z,然后根据这个向量还原出输入X。通过对比X与X∼\overset{\sim}{X}X的误差,再利用神经网络去训练使得误差逐渐减小,从而达到非监督学习的目的。
下图为AE的架构图:
在这里插入图片描述
自编码器不能随意产生合理的潜在变量,从而导致它无法产生新的内容。因为潜在变量Z都是编码器从原始图片中产生的。为了解决这一问题,研究人员对潜在空间Z(潜在变量对应的空间) 增加一些约束,使 Z 满足正态分布,由此就出现了VAE模型, VAE对编码器添加约束,就是强迫它产生服从单位正态分布的潜在变量。正是这种约束,把VAE和 AE 区分开来。

VAE

变分自编码器关键一点就是增加一个对潜在空间Z的正态分布约束,如何确定这个正态分布就成为主要目标,我们知道要确定正态分布,只要确定其两个参数: 均值μ\muμ和标准差σ\sigmaσ。那如何确定 μ,σ\mu, \sigmaμ,σ呢?用一般的方法或估计比较麻烦效果也不好,研究人员发现**用神经网络去拟合,简单效果也不错。**下图为VAE的架构图:
在这里插入图片描述
上图中,模块①的功能把输入样本X通过编码器输出两个m维向量(μ,log_var\mu, \mathrm{log\_var}μ,log_var), 这两个向量是潜在空间(假设满足正态分布)的两个参数(相当于均值和方差)。那么如何从这个潜在空间采样一个点 Z ?

这里假设潜在正态分布能生成输入图像,从标准正态分布 N(0, 1)中采样一个 ϵ\epsilonϵ(模块②的功能), 然后使
Z=μ+elog_var∗ϵZ=\mu + e^{log\_var}*\epsilonZ=μ+elog_varϵ
这也是模块③的主要功能。
Z是从潜在空间抽取的一个向量,Z通过解码器生成一个样本X∼\overset{\sim}{X}X, 这是模块④的功能。这里的 ϵ\epsilonϵ 是随机采样的,这就可以保证潜在空间的连续性,良好的结构性。而这些特性使得潜在空间的每个方向都表示数据中有意义的变化方向。

以上的这些步骤构成整个网络的前向传播过程,那反向传播应如何进行?要确定反向传播就会设计损失函数,损失函数是衡量模型优劣的主要指标。这里我们需要从以下两个方面进行衡量。
1)生成的新图像与原图像的相似度;
2)隐含空间的分布与正态分布的相似度。

度量图像的相似度一般采用交叉熵(如nn.BCELoss) , 度量两个分布的相似度一般采用KL散度(Kullback-Leibler divergence)。这两个度量的和构成了整个模型的损失函数。

以下是损失函数的具体代码,VAE损失函数的推导过程可以参考原论文

# 定义重构损失函数及KL散度
reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
kl_div = -0.5*torch.sum(1+log_var-mu.pow(2)-log_var.exp())
# 两者相加得总损失
loss = reconst_loss + kl_div

讲讲为什么是log_var

这里可以看成 log_var = log⁡σ\log \sigmalogσ,所以Z=μ+elog_var∗ϵZ=\mu + e^{log\_var}*\epsilonZ=μ+elog_varϵ也就是Z=μ+σ∗ϵZ=\mu + \sigma*\epsilonZ=μ+σϵ
其中ϵ∼N(0,1)\epsilon\sim N(0,1)ϵN(0,1), 这里涉及到重参数化reparameterization。

为什么要用重参数化技巧

如果想从高斯分布N(μ,σ2)N(\mu,\sigma^{2})N(μ,σ2)中采样,可以先从标准分布N(0,1)N(0,1)N(0,1)采样出 ϵ\epsilonϵ , 再得到 Z=σ∗ϵ+μZ = \sigma*\epsilon+\muZ=σϵ+μ.
这样做的好处是

  1. 如果直接对N(μ,σ2)N(\mu,\sigma^{2})N(μ,σ2)进行采样得到Z,则Z无法对μ,σ\mu,\sigmaμ,σ进行求偏导

  2. 将随机性转移到了 ϵ\epsilonϵ 这个常量上,而 σ\sigmaσμ\muμ则当做仿射变换网络的一部分,这样得到的Z=σ∗ϵ+μZ = \sigma*\epsilon+\muZ=σϵ+μ,则Z就可以对μ,σ\mu,\sigmaμ,σ进行求偏导来计算损失函数,进行求梯度,进行BP。

用VAE生成图像

下面将结合代码,用pytorch实现,为便于说明起见,数据集采用MNIST,整个网络架构如下图所示。
VAE网络架构图

# 1. 导入必要的包
import os 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision 
from torchvision import transforms 
from torchvision.utils import save_image# 2. 定义一些超参数
image_size = 784 # 28*28 
h_dim = 400 
z_dim = 20 
num_epochs = 30 
batch_size = 128 
learning_rate = 0.001 # 如果没有文件夹就创建一个文件夹
sample_dir = 'samples'
if not os.path.exists(sample_dir):os.makedirs(sample_dir)
  1. 对数据集进行预处理,如转换为Tensor, 把数据集转换为循环,可批量加载的数据集
# 只下载训练数据集即可
# 下载MNIST训练集
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(),download=True)# 数据加载
data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)
  1. 构建VAE模型,主要由Encoder和Decoder两部分组成

# 定义AVE模型
class VAE(nn.Module):def __init__(self, image_size=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.fc1 = nn.Linear(image_size, h_dim)self.fc2 = nn.Linear(h_dim, z_dim)self.fc3 = nn.Linear(h_dim, z_dim)self.fc4 = nn.Linear(z_dim, h_dim)self.fc5 = nn.Linear(h_dim, image_size)def encoder(self, x):h = F.relu(self.fc1(x))return self.fc2(h), self.fc3(h)# 用mu, log_var生成一个潜在空间点z, mu, log_var为两个统计参数,我们假设# 这个假设分布能生成图像def reparameterize(self, mu, log_var):std = torch.exp(log_var/2)eps = torch.randn_like(std)return mu + eps * std def decoder(self, z):h = F.relu(self.fc4(z))return F.sigmoid(self.fc5(h))def forward(self, x):mu, log_var = self.encoder(x)z = self.reparameterize(mu, log_var)x_reconst = self.decoder(z)return x_reconst, mu, log_var 
  1. 选择GPU和优化器
# 选择GPU和优化器
torch.cuda.set_device(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  1. 训练模型,同时保存原图像与随机生成的图像
# 训练模型,同时保存原图像与随机生成的图像
for epoch in range(num_epochs):for i, (x, _) in enumerate(data_loader):# 获取样本,并前向传播x = x.to(device).view(-1, image_size)x_reconst, mu, log_var = model(x)# 计算重构损失和KL散度(KL散度用于衡量两种分布的相似程度)# KL散度的计算可以参考https://shenxiaohai.me/2018/10/20/pytorch-tutorial-advanced-02/reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())# 反向传播和优化loss = reconst_loss + kl_div optimizer.zero_grad()loss.backward()optimizer.step()if (i+1)%100 == 0:print('Epoch [{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div:{:.4f}'.format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))# 利用训练的模型进行测试with torch.no_grad():# 保存采样图像,即潜在向量z通过解码器生成的新图像# 随机生成的图像z = torch.randn(batch_size, z_dim).to(device)out = model.decoder(z).view(-1, 1, 28, 28)save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))# 保存重构图像,即原图像通过解码器生成的图像out, _, _ = model(x)x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))
  1. 展示原图像及重构图像
#显示图片
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg 
import numpy as np recons_path = './samples/reconst-30.png'
Image = mpimg.imread(recons_path)
plt.imshow(Image)
plt.axis('off')
plt.show()
# reconst
# 奇数列为原图像,欧数列为原图像重构的图像,可以看出重构效果还不错。

在这里插入图片描述
8. 由潜在空间通过解码器生成的新图像,这个图像效果也不错


# sampled
# 为由潜在空间通过解码器生成的新图像,这个图像效果也不错
genPath = './samples/sampled-30.png'
Image = mpimg.imread(genPath)
plt.imshow(Image)
plt.axis('off')
plt.show()

在这里插入图片描述
总结:这里构建网络主要用全连接层,有兴趣的读者,可以把卷积层,如果编码层使用卷积层(如nn.Conv2d), 则解码器就需要使用反卷积层(如nn.ConvTranspose2d)。

相关文章:

用VAE生成图像

用VAE生成图像自编码器AE,auto-encoderVAE讲讲为什么是log_var为什么要用重参数化技巧用VAE生成图像变分自编码器是自编码器的改进版本,自编码器AE是一种无监督学习,但它无法产生新的内容,变分自编码器对其潜在空间进行拓展&#…...

你只会说MVC模型是什么但是不会实现?今天带你走通Web、Servlet、MVC、SpringMVC。代码演示很清晰

文章目录HTTP请求和HTTP响应从0手写一个Web服务器,看看能有多累人使用Servlet实现一个服务器,看看多简单Serlvet的创建Servlet的运行Servlet的其他问题Servlet这么爽,我们简单地探索一下它的原理JSP跟Servlet合作啦,我们来看一下他…...

C++中邻接矩阵、邻接表、链式前向星具体用法及讲解

图论在提高组中几乎占据半壁江山,而今天要讲的就是如何存储一个图一.邻接矩阵原理要建立一个图,根本的要素就是边和点而想要让计算机存储边和点就需要用到一些数据结构邻接矩阵是最简单的他使用了一个二维数组,来表示一个图假设数组名为map那…...

appium的安装详解

安装appium 爬虫手机APP需要实现自动化,所以要使用appnium来实现点击,输入,滑动等操作。由于appnium的安装较为繁琐,所以特意整理一篇文章来展示安装的详细过程过程中。 安装appnium共有3个步骤 安装 Android SDK安装 JDK安装 …...

STM32之 串口

串口通信串行接口简称串口,也称串行通信接口或串行通讯接口(通常指COM接口),是采用串行通信方 式的扩展接口。串行接口(Serial Interface)是指数据一位一位地顺序传送。其特点是通信线路简 单,只…...

CSDN 编程竞赛三十三期题解

竞赛总览 CSDN 编程竞赛三十三期题解:比赛详情 (csdn.net) 竞赛题解 题目1、奇偶排序 给定一个存放整数的数组,重新排列数组使得数组左边为奇数,右边为偶数(奇数和偶数的顺序根据输入的数字顺序排列)。 第七期竞赛…...

逆向练习之 mingyue.exe wp

目录 一.查壳 二.主函数 三.operate函数 四.storage函数及4618和4620指针功能的解释 五.judge函数 六.求解flag 七.其他--ida字符识别问题 一.查壳 64位无壳 二.主函数 1.这里的pointer_4618和4620是两个相邻的八字节内存单元,其中4620是字符串链表表头head 2.puts和s…...

LeetCode 热题 HOT 100 Java 题解 -- Part 3

练习地址 Part 1 : https://blog.csdn.net/qq_41080854/article/details/128829494 Part 2 : https://blog.csdn.net/qq_41080854/article/details/129278336 LeetCode 热题 HOT 100 Java 题解 -- Part 376. 最佳买卖股票时机含冷冻期77. 戳气球78. 零钱兑换79. 打家劫舍 III…...

QML键盘事件

在QML中,当有一个按键按下或释放时,会产生一个键盘事件,将其传递给获得有焦点的QML项目(讲focus属性设置为true,则获得焦点)。 按键处理的基本流程: Qt接收密钥操作并生成密钥事件。如果 QQuic…...

跨域问题怎么解决

解决跨域,原因:域名不同,域名相同端口不同;二级域名不同 什么是跨域? 就是两个项目之间通讯,如果访问的域名与ajax访问的地址不一致情况,默认情况浏览器有一个安全机制。 postman不一定能测试…...

微服务网关Gateway和Zuul的区别

spring-cloud-Gateway是spring-cloud的一个子项目。而zuul则是netflix公司的项目,只是spring将zuul集成在spring-cloud中使用而已。 因为zuul2.0连续跳票和zuul1的性能表现不是很理想,所以催生了spring团队开发了Gateway项目。 Zuul: 使用的…...

专访华西二院吴邦华:隐私计算+AI全栈技术,构筑智慧医院建设的坚实数据底座|爱分析访谈

从IT时代步入DT时代,医疗大数据成为智慧医院建设的重要驱动力。经过多年信息化系统建设,很多医院已经积累了大量的医疗数据资源,但由于各业务系统间数据孤岛化严重、系统架构落后、数据缺乏深度治理等问题存在,导致现有数据深度及…...

《C++ Primer Plus》第18章:探讨 C++ 新标准(6)

可变参数模板 可变参数模板(variadic template)让您能够创建这样的模板函数和模板类,即可接收可变数量的参数。这里介绍可变参数模板函数。例如,假设要编写一个函数,它可接受任意数量的参数,参数的类型只需…...

.Net Core中使用是SQL Server的邮件发送功能

.Net Core中使用是sqlserver的邮件发送功能准备需求启用SQL Server的电子邮件功能检查和测试在.net Core中调用在sqlsrver的管理中有一个数据库邮件功能,再此可以使用sqlserver来自动发送一些邮件,但是有一些需要插入附件的邮件则需要使用程序代码来解决,下面就是使用C#来调用s…...

Nginx优化服务和防盗链

Nginx优化服务和防盗链一、长连接1、修改主配置文件2、测试3、在主配置文件添加4、验证二、Nginx第三方模块1、开源的echo模块2、查看是否成功3、加echo模块步骤4、网页测试验证三、搭建虚拟主机1、编译安装好nginx后,对主配置文件进行修改2、创建文件3、验证四、防…...

B树与B+树

认识了解MySQL中的B树B树引出什么是B树什么是B树B树的优点B树引出 在MySQL中,如果我们设置了主键, 那么对于该列表中的数据就有了一个索引,插入表中数据的主键值不能重复,而且不能为空. 那当我们插入数据的时候, 它是如何通过索引来判断主键值是否重复的呢? 我们想到它肯定是…...

QEMU网络配置

文章目录1. 前言2. 测试环境3. 配置步骤3.1 host 配置3.1.1 检查 host 对 TUN/TAP 和 网桥的支持情况3.1.2 网桥一端的建立:创建网桥设备,并添加 host 网卡到网桥3.1.3 网桥另一端的建立:TUN/TAP 配置3.2 guest 端的配置4. 参考链接1. 前言 …...

windows安装tomcat

这里写自定义目录标题tomcat官网下载安装包并解压环境变量配置启动tomcat访问http://localhost:8080/修复启动出现乱码问题tomcat官网下载安装包并解压 环境变量配置 系统环境变量新增: 变量名:CATALINA_HOME 变量值:tomcat的安装目录 编辑…...

刷题记录:牛客NC23051华华和月月种树 树链剖分+离线加点

传送门:牛客 题目描述: 华华看书了解到,一起玩养成类的游戏有助于两人培养感情。所以他决定和月月一起种一棵树。因为华华现在也是信息学高手了,所以他们种的树是信息学意义下的。 华华和月月一起维护了一棵动态有根树,每个点有一个权值。刚…...

年薪20W软件测试工程师必备的6大技能(建议收藏)

软件测试 随着软件开发行业的日益发展,岗位需求量和行业薪资都不断增长,想要入行的人也是越来越多,但不知道从哪里下手,今天,就给大家分享一下,软件测试行业都有哪些必会的方法和技术知识点,作…...

label-studio的使用教程(导入本地路径)

文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

Leetcode 3576. Transform Array to All Equal Elements

Leetcode 3576. Transform Array to All Equal Elements 1. 解题思路2. 代码实现 题目链接:3576. Transform Array to All Equal Elements 1. 解题思路 这一题思路上就是分别考察一下是否能将其转化为全1或者全-1数组即可。 至于每一种情况是否可以达到&#xf…...

C# 类和继承(抽象类)

抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...

Pinocchio 库详解及其在足式机器人上的应用

Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库,专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性,并提供了一个通用的框架&…...

Python ROS2【机器人中间件框架】 简介

销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...

C#中的CLR属性、依赖属性与附加属性

CLR属性的主要特征 封装性: 隐藏字段的实现细节 提供对字段的受控访问 访问控制: 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性: 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑: 可以…...

莫兰迪高级灰总结计划简约商务通用PPT模版

莫兰迪高级灰总结计划简约商务通用PPT模版,莫兰迪调色板清新简约工作汇报PPT模版,莫兰迪时尚风极简设计PPT模版,大学生毕业论文答辩PPT模版,莫兰迪配色总结计划简约商务通用PPT模版,莫兰迪商务汇报PPT模版,…...

NPOI Excel用OLE对象的形式插入文件附件以及插入图片

static void Main(string[] args) {XlsWithObjData();Console.WriteLine("输出完成"); }static void XlsWithObjData() {// 创建工作簿和单元格,只有HSSFWorkbook,XSSFWorkbook不可以HSSFWorkbook workbook new HSSFWorkbook();HSSFSheet sheet (HSSFSheet)workboo…...

论文阅读:Matting by Generation

今天介绍一篇关于 matting 抠图的文章,抠图也算是计算机视觉里面非常经典的一个任务了。从早期的经典算法到如今的深度学习算法,已经有很多的工作和这个任务相关。这两年 diffusion 模型很火,大家又开始用 diffusion 模型做各种 CV 任务了&am…...

数据结构第5章:树和二叉树完全指南(自整理详细图文笔记)

名人说:莫道桑榆晚,为霞尚满天。——刘禹锡(刘梦得,诗豪) 原创笔记:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 上一篇:《数据结构第4章 数组和广义表》…...