当前位置: 首页 > news >正文

用VAE生成图像

用VAE生成图像

  • 自编码器AE,auto-encoder
  • VAE
    • 讲讲为什么是log_var
    • 为什么要用重参数化技巧
  • 用VAE生成图像

变分自编码器是自编码器的改进版本,自编码器AE是一种无监督学习,但它无法产生新的内容,变分自编码器对其潜在空间进行拓展,使其满足正态分布,情况就大不一样了。

自编码器AE,auto-encoder

自编码器是通过对输入X进行编码后得到一个低维的向量z,然后根据这个向量还原出输入X。通过对比X与X∼\overset{\sim}{X}X的误差,再利用神经网络去训练使得误差逐渐减小,从而达到非监督学习的目的。
下图为AE的架构图:
在这里插入图片描述
自编码器不能随意产生合理的潜在变量,从而导致它无法产生新的内容。因为潜在变量Z都是编码器从原始图片中产生的。为了解决这一问题,研究人员对潜在空间Z(潜在变量对应的空间) 增加一些约束,使 Z 满足正态分布,由此就出现了VAE模型, VAE对编码器添加约束,就是强迫它产生服从单位正态分布的潜在变量。正是这种约束,把VAE和 AE 区分开来。

VAE

变分自编码器关键一点就是增加一个对潜在空间Z的正态分布约束,如何确定这个正态分布就成为主要目标,我们知道要确定正态分布,只要确定其两个参数: 均值μ\muμ和标准差σ\sigmaσ。那如何确定 μ,σ\mu, \sigmaμ,σ呢?用一般的方法或估计比较麻烦效果也不好,研究人员发现**用神经网络去拟合,简单效果也不错。**下图为VAE的架构图:
在这里插入图片描述
上图中,模块①的功能把输入样本X通过编码器输出两个m维向量(μ,log_var\mu, \mathrm{log\_var}μ,log_var), 这两个向量是潜在空间(假设满足正态分布)的两个参数(相当于均值和方差)。那么如何从这个潜在空间采样一个点 Z ?

这里假设潜在正态分布能生成输入图像,从标准正态分布 N(0, 1)中采样一个 ϵ\epsilonϵ(模块②的功能), 然后使
Z=μ+elog_var∗ϵZ=\mu + e^{log\_var}*\epsilonZ=μ+elog_varϵ
这也是模块③的主要功能。
Z是从潜在空间抽取的一个向量,Z通过解码器生成一个样本X∼\overset{\sim}{X}X, 这是模块④的功能。这里的 ϵ\epsilonϵ 是随机采样的,这就可以保证潜在空间的连续性,良好的结构性。而这些特性使得潜在空间的每个方向都表示数据中有意义的变化方向。

以上的这些步骤构成整个网络的前向传播过程,那反向传播应如何进行?要确定反向传播就会设计损失函数,损失函数是衡量模型优劣的主要指标。这里我们需要从以下两个方面进行衡量。
1)生成的新图像与原图像的相似度;
2)隐含空间的分布与正态分布的相似度。

度量图像的相似度一般采用交叉熵(如nn.BCELoss) , 度量两个分布的相似度一般采用KL散度(Kullback-Leibler divergence)。这两个度量的和构成了整个模型的损失函数。

以下是损失函数的具体代码,VAE损失函数的推导过程可以参考原论文

# 定义重构损失函数及KL散度
reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
kl_div = -0.5*torch.sum(1+log_var-mu.pow(2)-log_var.exp())
# 两者相加得总损失
loss = reconst_loss + kl_div

讲讲为什么是log_var

这里可以看成 log_var = log⁡σ\log \sigmalogσ,所以Z=μ+elog_var∗ϵZ=\mu + e^{log\_var}*\epsilonZ=μ+elog_varϵ也就是Z=μ+σ∗ϵZ=\mu + \sigma*\epsilonZ=μ+σϵ
其中ϵ∼N(0,1)\epsilon\sim N(0,1)ϵN(0,1), 这里涉及到重参数化reparameterization。

为什么要用重参数化技巧

如果想从高斯分布N(μ,σ2)N(\mu,\sigma^{2})N(μ,σ2)中采样,可以先从标准分布N(0,1)N(0,1)N(0,1)采样出 ϵ\epsilonϵ , 再得到 Z=σ∗ϵ+μZ = \sigma*\epsilon+\muZ=σϵ+μ.
这样做的好处是

  1. 如果直接对N(μ,σ2)N(\mu,\sigma^{2})N(μ,σ2)进行采样得到Z,则Z无法对μ,σ\mu,\sigmaμ,σ进行求偏导

  2. 将随机性转移到了 ϵ\epsilonϵ 这个常量上,而 σ\sigmaσμ\muμ则当做仿射变换网络的一部分,这样得到的Z=σ∗ϵ+μZ = \sigma*\epsilon+\muZ=σϵ+μ,则Z就可以对μ,σ\mu,\sigmaμ,σ进行求偏导来计算损失函数,进行求梯度,进行BP。

用VAE生成图像

下面将结合代码,用pytorch实现,为便于说明起见,数据集采用MNIST,整个网络架构如下图所示。
VAE网络架构图

# 1. 导入必要的包
import os 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision 
from torchvision import transforms 
from torchvision.utils import save_image# 2. 定义一些超参数
image_size = 784 # 28*28 
h_dim = 400 
z_dim = 20 
num_epochs = 30 
batch_size = 128 
learning_rate = 0.001 # 如果没有文件夹就创建一个文件夹
sample_dir = 'samples'
if not os.path.exists(sample_dir):os.makedirs(sample_dir)
  1. 对数据集进行预处理,如转换为Tensor, 把数据集转换为循环,可批量加载的数据集
# 只下载训练数据集即可
# 下载MNIST训练集
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(),download=True)# 数据加载
data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)
  1. 构建VAE模型,主要由Encoder和Decoder两部分组成

# 定义AVE模型
class VAE(nn.Module):def __init__(self, image_size=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.fc1 = nn.Linear(image_size, h_dim)self.fc2 = nn.Linear(h_dim, z_dim)self.fc3 = nn.Linear(h_dim, z_dim)self.fc4 = nn.Linear(z_dim, h_dim)self.fc5 = nn.Linear(h_dim, image_size)def encoder(self, x):h = F.relu(self.fc1(x))return self.fc2(h), self.fc3(h)# 用mu, log_var生成一个潜在空间点z, mu, log_var为两个统计参数,我们假设# 这个假设分布能生成图像def reparameterize(self, mu, log_var):std = torch.exp(log_var/2)eps = torch.randn_like(std)return mu + eps * std def decoder(self, z):h = F.relu(self.fc4(z))return F.sigmoid(self.fc5(h))def forward(self, x):mu, log_var = self.encoder(x)z = self.reparameterize(mu, log_var)x_reconst = self.decoder(z)return x_reconst, mu, log_var 
  1. 选择GPU和优化器
# 选择GPU和优化器
torch.cuda.set_device(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  1. 训练模型,同时保存原图像与随机生成的图像
# 训练模型,同时保存原图像与随机生成的图像
for epoch in range(num_epochs):for i, (x, _) in enumerate(data_loader):# 获取样本,并前向传播x = x.to(device).view(-1, image_size)x_reconst, mu, log_var = model(x)# 计算重构损失和KL散度(KL散度用于衡量两种分布的相似程度)# KL散度的计算可以参考https://shenxiaohai.me/2018/10/20/pytorch-tutorial-advanced-02/reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())# 反向传播和优化loss = reconst_loss + kl_div optimizer.zero_grad()loss.backward()optimizer.step()if (i+1)%100 == 0:print('Epoch [{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div:{:.4f}'.format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))# 利用训练的模型进行测试with torch.no_grad():# 保存采样图像,即潜在向量z通过解码器生成的新图像# 随机生成的图像z = torch.randn(batch_size, z_dim).to(device)out = model.decoder(z).view(-1, 1, 28, 28)save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))# 保存重构图像,即原图像通过解码器生成的图像out, _, _ = model(x)x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))
  1. 展示原图像及重构图像
#显示图片
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg 
import numpy as np recons_path = './samples/reconst-30.png'
Image = mpimg.imread(recons_path)
plt.imshow(Image)
plt.axis('off')
plt.show()
# reconst
# 奇数列为原图像,欧数列为原图像重构的图像,可以看出重构效果还不错。

在这里插入图片描述
8. 由潜在空间通过解码器生成的新图像,这个图像效果也不错


# sampled
# 为由潜在空间通过解码器生成的新图像,这个图像效果也不错
genPath = './samples/sampled-30.png'
Image = mpimg.imread(genPath)
plt.imshow(Image)
plt.axis('off')
plt.show()

在这里插入图片描述
总结:这里构建网络主要用全连接层,有兴趣的读者,可以把卷积层,如果编码层使用卷积层(如nn.Conv2d), 则解码器就需要使用反卷积层(如nn.ConvTranspose2d)。

相关文章:

用VAE生成图像

用VAE生成图像自编码器AE,auto-encoderVAE讲讲为什么是log_var为什么要用重参数化技巧用VAE生成图像变分自编码器是自编码器的改进版本,自编码器AE是一种无监督学习,但它无法产生新的内容,变分自编码器对其潜在空间进行拓展&#…...

你只会说MVC模型是什么但是不会实现?今天带你走通Web、Servlet、MVC、SpringMVC。代码演示很清晰

文章目录HTTP请求和HTTP响应从0手写一个Web服务器,看看能有多累人使用Servlet实现一个服务器,看看多简单Serlvet的创建Servlet的运行Servlet的其他问题Servlet这么爽,我们简单地探索一下它的原理JSP跟Servlet合作啦,我们来看一下他…...

C++中邻接矩阵、邻接表、链式前向星具体用法及讲解

图论在提高组中几乎占据半壁江山,而今天要讲的就是如何存储一个图一.邻接矩阵原理要建立一个图,根本的要素就是边和点而想要让计算机存储边和点就需要用到一些数据结构邻接矩阵是最简单的他使用了一个二维数组,来表示一个图假设数组名为map那…...

appium的安装详解

安装appium 爬虫手机APP需要实现自动化,所以要使用appnium来实现点击,输入,滑动等操作。由于appnium的安装较为繁琐,所以特意整理一篇文章来展示安装的详细过程过程中。 安装appnium共有3个步骤 安装 Android SDK安装 JDK安装 …...

STM32之 串口

串口通信串行接口简称串口,也称串行通信接口或串行通讯接口(通常指COM接口),是采用串行通信方 式的扩展接口。串行接口(Serial Interface)是指数据一位一位地顺序传送。其特点是通信线路简 单,只…...

CSDN 编程竞赛三十三期题解

竞赛总览 CSDN 编程竞赛三十三期题解:比赛详情 (csdn.net) 竞赛题解 题目1、奇偶排序 给定一个存放整数的数组,重新排列数组使得数组左边为奇数,右边为偶数(奇数和偶数的顺序根据输入的数字顺序排列)。 第七期竞赛…...

逆向练习之 mingyue.exe wp

目录 一.查壳 二.主函数 三.operate函数 四.storage函数及4618和4620指针功能的解释 五.judge函数 六.求解flag 七.其他--ida字符识别问题 一.查壳 64位无壳 二.主函数 1.这里的pointer_4618和4620是两个相邻的八字节内存单元,其中4620是字符串链表表头head 2.puts和s…...

LeetCode 热题 HOT 100 Java 题解 -- Part 3

练习地址 Part 1 : https://blog.csdn.net/qq_41080854/article/details/128829494 Part 2 : https://blog.csdn.net/qq_41080854/article/details/129278336 LeetCode 热题 HOT 100 Java 题解 -- Part 376. 最佳买卖股票时机含冷冻期77. 戳气球78. 零钱兑换79. 打家劫舍 III…...

QML键盘事件

在QML中,当有一个按键按下或释放时,会产生一个键盘事件,将其传递给获得有焦点的QML项目(讲focus属性设置为true,则获得焦点)。 按键处理的基本流程: Qt接收密钥操作并生成密钥事件。如果 QQuic…...

跨域问题怎么解决

解决跨域,原因:域名不同,域名相同端口不同;二级域名不同 什么是跨域? 就是两个项目之间通讯,如果访问的域名与ajax访问的地址不一致情况,默认情况浏览器有一个安全机制。 postman不一定能测试…...

微服务网关Gateway和Zuul的区别

spring-cloud-Gateway是spring-cloud的一个子项目。而zuul则是netflix公司的项目,只是spring将zuul集成在spring-cloud中使用而已。 因为zuul2.0连续跳票和zuul1的性能表现不是很理想,所以催生了spring团队开发了Gateway项目。 Zuul: 使用的…...

专访华西二院吴邦华:隐私计算+AI全栈技术,构筑智慧医院建设的坚实数据底座|爱分析访谈

从IT时代步入DT时代,医疗大数据成为智慧医院建设的重要驱动力。经过多年信息化系统建设,很多医院已经积累了大量的医疗数据资源,但由于各业务系统间数据孤岛化严重、系统架构落后、数据缺乏深度治理等问题存在,导致现有数据深度及…...

《C++ Primer Plus》第18章:探讨 C++ 新标准(6)

可变参数模板 可变参数模板(variadic template)让您能够创建这样的模板函数和模板类,即可接收可变数量的参数。这里介绍可变参数模板函数。例如,假设要编写一个函数,它可接受任意数量的参数,参数的类型只需…...

.Net Core中使用是SQL Server的邮件发送功能

.Net Core中使用是sqlserver的邮件发送功能准备需求启用SQL Server的电子邮件功能检查和测试在.net Core中调用在sqlsrver的管理中有一个数据库邮件功能,再此可以使用sqlserver来自动发送一些邮件,但是有一些需要插入附件的邮件则需要使用程序代码来解决,下面就是使用C#来调用s…...

Nginx优化服务和防盗链

Nginx优化服务和防盗链一、长连接1、修改主配置文件2、测试3、在主配置文件添加4、验证二、Nginx第三方模块1、开源的echo模块2、查看是否成功3、加echo模块步骤4、网页测试验证三、搭建虚拟主机1、编译安装好nginx后,对主配置文件进行修改2、创建文件3、验证四、防…...

B树与B+树

认识了解MySQL中的B树B树引出什么是B树什么是B树B树的优点B树引出 在MySQL中,如果我们设置了主键, 那么对于该列表中的数据就有了一个索引,插入表中数据的主键值不能重复,而且不能为空. 那当我们插入数据的时候, 它是如何通过索引来判断主键值是否重复的呢? 我们想到它肯定是…...

QEMU网络配置

文章目录1. 前言2. 测试环境3. 配置步骤3.1 host 配置3.1.1 检查 host 对 TUN/TAP 和 网桥的支持情况3.1.2 网桥一端的建立:创建网桥设备,并添加 host 网卡到网桥3.1.3 网桥另一端的建立:TUN/TAP 配置3.2 guest 端的配置4. 参考链接1. 前言 …...

windows安装tomcat

这里写自定义目录标题tomcat官网下载安装包并解压环境变量配置启动tomcat访问http://localhost:8080/修复启动出现乱码问题tomcat官网下载安装包并解压 环境变量配置 系统环境变量新增: 变量名:CATALINA_HOME 变量值:tomcat的安装目录 编辑…...

刷题记录:牛客NC23051华华和月月种树 树链剖分+离线加点

传送门:牛客 题目描述: 华华看书了解到,一起玩养成类的游戏有助于两人培养感情。所以他决定和月月一起种一棵树。因为华华现在也是信息学高手了,所以他们种的树是信息学意义下的。 华华和月月一起维护了一棵动态有根树,每个点有一个权值。刚…...

年薪20W软件测试工程师必备的6大技能(建议收藏)

软件测试 随着软件开发行业的日益发展,岗位需求量和行业薪资都不断增长,想要入行的人也是越来越多,但不知道从哪里下手,今天,就给大家分享一下,软件测试行业都有哪些必会的方法和技术知识点,作…...

【存储】RAID2.0+、多路径技术、磁盘可靠性技术

RAID2.0RAID 2.0技术RAID技术发展RAID 2.0软件逻辑对象RAID 2.0基本原理硬盘域Storage Pool & TierDisk Group(DG)LD(逻辑磁盘)Chunk(CK)Chunk Group(CKG)ExtentGrainVolume &am…...

Vue 2

文章目录1. 简介2. 第一个Vue程序3. 指令3.1 判断循环3.2 操作属性3.3 绑定事件3.4 表单中数据双向绑定3.5 其他内置指令3.6 自定义指令4. 组件4.1 全局注册4.2 局部注册4.3 组件通讯4.4 单文件组件5. 组件插槽5.1 单个插槽5.2 具名插槽5.3 作用域插槽6. 内置组件6.1 component…...

Ubuntu 安装 Docker Engine

【参考】Install Docker Engine on Ubuntu | Docker Documentation: https://docs.docker.com/engine/install/ubuntu/ 【参考】Docker CE 镜像源站-阿里云开发者社区 https://developer.aliyun.com/article/110806 【规范】模仿 Docker 文档,Ubuntu, Docker 首字母…...

SpringBoot入门 - 添加内存数据库H2

上文我们展示了通过学习经典的MVC分包结构展示了一个用户的增删查改项目,但是我们没有接入数据库;本文将在上文的基础上,增加一个H2内存数据库,并且通过Spring 提供的数据访问包JPA进行数据查询。准备知识点在介绍通过Spring JPA接…...

高质量数字化转型创新发展大会暨中国信通院“铸基计划”年度会议成功召开

2023年3月3日,由中国信通院主办的高质量数字化转型创新发展大会暨中国信通院“铸基计划”年度会议在北京成功召开。本次大会深度展示了中国信通院在数字化领域的工作成果,并全面展望了2023年行业的数字化发展趋势。同时,大会发布了中国信通院…...

2023年如何通过软考初级程序员?

初级的考试难度不大,稍微有点编程基础,认真备考应该没什么大问题。 先清楚大纲: 高效备考!理清考点,针对性复习 科目一:综合知识 75道单项选择题,1题1分,时长150分钟;…...

视频自动播放的实现与问题解决

一、前言 页面加载一个视频并且自动播放,这个需求看起来非常简单,实现起来感觉也非常简单;但是,实际做起来还是有几处容易产生问题的地方卡住进度。本文讨论基于Vue3的项目在实现页面加载视频后的自动播放遇到的几个问题。 二、页面实现 页面实现非常简单。在页面上放置一个…...

ThreadLocal 理解及面试

一、ThreadLocal 引用关系 图解关系说明: 每个线程拥有自己的 ThreadLocalMap 属性;ThreadLocalMap 的存储结构为 Entry[] 数组;Entry的Key是ThreadLocal类型且弱引用指向ThreadLocal对象,Value是我们自己定义的泛型值对象&#…...

巾帼绽芬芳 一起向未来(中篇)

编者按:为了隆重纪念纪念“三八”国际妇女节113周年,快来与你全方位、多层次分享交流“三八”国际妇女节的前世今生。分上篇(节日简介、节日发展和节日意义)、中篇(节日活动宗旨和世界各国庆祝方式)和下篇&…...

Qt学习2-Qt Creator新建项目小tips(哔站视频学习记录)

放送两个小tips: 1、MinGW和MSVC的区别 QT学习笔记(二):QT MinGW 和 MSVC 编译方式_Leon_Chan0的博客-CSDN博客 2、如何安装QT对应版本的MSVC (1)问题描述:Qt5.12.8支持MSVC2015和MSVC2017,但是系统安装的是Visual…...

wordpress 多菜单/百度文库官网入口

procstat当前服务器进程性能参数(所有类型的进程都有)cpu_usage:当前服务器进程cpu的占用率,所有子线程的cpu占用之后,每个核算100%memory_rss:当前服务器进程占用的物理内存cpu_thread当前服务器进程的各个子线程的性能参数(所有…...

装修公司网站开发/中国营销网

C操作符的优先级 C操作符的优先级 操作符及其结合性 功能 用法 L L L:: :: ::全局作用域 类作用域 名字空间作用域::name class::name namespace::nameL L L L L. -> [] () ()成员选择 成员选择 下标 函数调用 类型构造object.member pointer->member variable[exp…...

专业网站建设首选公司/网络运营具体做什么

http://blog.csdn.net/0210/article/details/5437368 http://blog.csdn.net/vebasan/article/details/5515235 ifconfig 如果不能用,就用/sbin/ifconfig ifconfig -a 所有的网卡情况 ifconfig 就会出现当前正在使用的网卡的情况, eth0 eth0:0 //和上…...

海南海口做网站/深圳网站设计实力乐云seo

记录一个小知识点: 我们会看到这样利用循环的代码: for _ in range(n): _ 在这里是什么意思? _ 是占位符:表示不在意变量的值,只是用于循环遍历n次,和使用i没什么区别,只不过使用_就表示我不在…...

济宁建设局官方网站/seo技术培训班

问题求解1: 从一个 44 的棋盘(不可旋转)中选取不在同一行也不在同一列上的两个方格,共有____72_____种方法。 假设选择第一行,共有4个格子可以选择,然后从剩余的3行中进行选择,有4X3种可能。…...

wordpress慢 数据库/seo网站优化助理

本文简介decorator模块是 Michele Simionato 为简化python的decorator的使用难度而开发的,使用它,您可以更加容易的使用decorator机制写出可读性、可维护性更好的代码。本文大部分翻译自下面这篇文档: www.phyast.pitt.edu/~micheles/python/documentati…...