当前位置: 首页 > news >正文

【扩散模型 李宏毅B站教学以及基础代码运用】

李宏毅教学视频:
Link1

B站DDPM公式推导以及代码实现:
Link2

这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。

文章目录

    • 扩散模型概念:
    • Diffusion Model工作原理:
    • 影像生成模型本质上的共同目标
    • B站简单示例代码讲解

扩散模型概念:

就像石头里面已经有了雕塑,只需要看我们怎么把其他多余的部分去掉。
在这里插入图片描述
注意观察,我们每一个Denoise阶段都不一样,因为每一个阶段传入的图片以及需要处理的noise都不一样,并且直接产生图片比直接产生噪音更难,所以我们通过预测noise来解决问题。
在这里插入图片描述

比如下图所示:step2是我们加的噪声,那么传入input和2的时候就希望预测出gt了,然后进行相减得到step1的图片。
在这里插入图片描述

Diffusion Model工作原理:

VAE和Diffusion的区别
在这里插入图片描述
先看整个训练过程:
在这里插入图片描述

实际结果和我们想的是不一样的。训练时通过X0和噪声得到一个图,逆向的时候输入t和生成的图来得到噪音。想象的是一点一点加入噪音,实际上是直接加进去的。在这里插入图片描述
推断时刻:theat是带有参数的网络。
在这里插入图片描述

影像生成模型本质上的共同目标

通过采样一个高深distribution生成一个图片。希望生成的图片和真实的图片的distribution很接近。
在这里插入图片描述
那么怎么衡量这两个分布的接近程度呢?多数采用的都是Maximum liklihood Estimation.
我们希望我们采样的数据能够通过theta网络计算出来的概率越大越好。 在这里插入图片描述
通过数学变换,将概率最大变为Pdata和Ptheat这两个distribution的KL散度最小。
在这里插入图片描述
VAE的下界
Ptheat(x)表示:通过theta产生x的概率。
在这里插入图片描述

在这里插入图片描述
DDPM计算Ptheta(x)的方法 下图表示产生X0的概率。
在这里插入图片描述
两者对比
在这里插入图片描述
接下来需要计算q(x1|x0)此类公式。
计算方法:X1到X2的计算方法在论文中有提及。
在这里插入图片描述
两个高斯分布都是服从N(0,1),相加的话还是一个高斯分布,并且还是服从N(0,1),只是前面系数会发生变化。系数的话是根号下面数字相加。所以相加之后均值还是为0,方差a方加b方即可,这个在另外一个视频里面有讲解。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
经过一番推导之后得到:
在这里插入图片描述
之后计算最下面三项:
在这里插入图片描述
通过以下推导:
在这里插入图片描述
之后通过X0,Xt可以得到Xt-1的分布。
在这里插入图片描述
可以看到前面一项的mean 和 variance是固定的,第二项的variance也是固定的,因此我们需要把第二项的mean变得和第一项的接近。
在这里插入图片描述
那么怎么minimiaze这个mean呢?希望用Xt去预测出来那个mean。
在这里插入图片描述
经过推导:
在这里插入图片描述
最终得到下图:
在这里插入图片描述
里面beta可以学习,但是效果不好,所以使用线性固定。最后加上一个噪声猜测是为了增强鲁棒性,并且本身就是从噪声开始,不加噪声的话可能不会生成图片。

B站简单示例代码讲解

# 加载数据集
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torchs_curve,_ = make_s_curve(10**4,noise=0.1)
print(np.shape(s_curve))
s_curve = s_curve[:,[0,2]]/10.0print("shape of s:",np.shape(s_curve))data = s_curve.Tfig,ax = plt.subplots()
ax.scatter(*data,color='blue',edgecolor='white');ax.axis('off')dataset = torch.Tensor(s_curve).float()

在这里插入图片描述

# 2确定超参数的值
num_steps = 100
#制定每一步的beta
betas = torch.linspace(-6,6,num_steps)
betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5)+1e-5#计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1-betas
alphas_prod = torch.cumprod(alphas,0)
# print(alphas_prod)
alphas_prod_p = torch.cat([torch.tensor([1]).float(),alphas_prod[:-1]],0)
# print(alphas_prod_p)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)assert alphas.shape==alphas_prod.shape==alphas_prod_p.shape==\
alphas_bar_sqrt.shape==one_minus_alphas_bar_log.shape\
==one_minus_alphas_bar_sqrt.shape
print("all the same shape",betas.shape)、确定扩散过程任意时刻的采样值#3 计算任意时刻的x采样值,基于x_0和重参数化
def q_x(x_0,t):"""可以基于x[0]得到任意时刻t的x[t]"""noise = torch.randn_like(x_0)alphas_t = alphas_bar_sqrt[t]alphas_1_m_t = one_minus_alphas_bar_sqrt[t]return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基础上添加噪声
j
# 4 演示原始数据分布加噪100步后的结果num_shows = 20
fig,axs = plt.subplots(2,10,figsize=(28,3))
plt.rc('text',color='black')#共有10000个点,每个点包含两个坐标
#生成100步以内每隔5步加噪声后的图像
for i in range(num_shows):j = i//10k = i%10q_i = q_x(dataset,torch.tensor([i*num_steps//num_shows]))#生成t时刻的采样数据axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')axs[j,k].set_axis_off()axs[j,k].set_title('$q(\mathbf{x}_{'+str(i*num_steps//num_shows)+'})$')

在这里插入图片描述

# 5 编写拟合逆扩散过程高斯分布的模型import torch
import torch.nn as nn
​
class MLPDiffusion(nn.Module):def __init__(self,n_steps,num_units=128):super(MLPDiffusion,self).__init__()self.linears = nn.ModuleList([nn.Linear(2,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,2),])self.step_embeddings = nn.ModuleList([nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),])def forward(self,x,t):
#         x = x_0for idx,embedding_layer in enumerate(self.step_embeddings):t_embedding = embedding_layer(t)x = self.linears[2*idx](x)x += t_embeddingx = self.linears[2*idx+1](x)x = self.linears[-1](x)return x

loss_fn 就是Lsimple得表达式。通过传入参数,生成一个随机噪声,并且送入model里面,那么上面也讲了,model的作用是根据X0,和t预测出我们应该减去的噪声,所以损失函数就是用生成的噪声减去预测的噪声。
在这里插入图片描述

# 6 编写训练的误差函数
def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):"""对任意时刻t进行采样计算loss"""batch_size = x_0.shape[0]#对一个batchsize样本生成随机的时刻tt = torch.randint(0,n_steps,size=(batch_size//2,))t = torch.cat([t,n_steps-1-t],dim=0)t = t.unsqueeze(-1)#x0的系数a = alphas_bar_sqrt[t]#eps的系数aml = one_minus_alphas_bar_sqrt[t]#生成随机噪音epse = torch.randn_like(x_0)#构造模型的输入x = x_0*a+e*aml#送入模型,得到t时刻的随机噪声预测值output = model(x,t.squeeze(-1))#与真实噪声一起计算误差,求平均值return torch.pow((e - output),2).mean()
# 7、编写逆扩散采样函数(inference)def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):"""从x[T]恢复x[T-1]、x[T-2]|...x[0]"""cur_x = torch.randn(shape)x_seq = [cur_x]for i in reversed(range(n_steps)):cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)x_seq.append(cur_x)return x_seq
​
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):"""从x[T]采样t时刻的重构值"""t = torch.tensor([t])coeff = betas[t] / one_minus_alphas_bar_sqrt[t]eps_theta = model(x,t)mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))z = torch.randn_like(x)sigma_t = betas[t].sqrt()sample = mean + sigma_t * zreturn (sample)
# 8、开始训练模型,打印loss及中间重构效果seed = 1234class EMA():"""构建一个参数平滑器"""def __init__(self,mu=0.01):self.mu = muself.shadow = {}def register(self,name,val):self.shadow[name] = val.clone()def __call__(self,name,x):assert name in self.shadownew_average = self.mu * x + (1.0-self.mu)*self.shadow[name]self.shadow[name] = new_average.clone()return new_averageprint('Training model...')
batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
num_epoch = 4000
plt.rc('text',color='blue')
​
model = MLPDiffusion(num_steps)#输出维度是2,输入是x和step
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)for t in range(num_epoch):for idx,batch_x in enumerate(dataloader):loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.)optimizer.step()if(t%100==0):print(loss)x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)fig,axs = plt.subplots(1,10,figsize=(28,3))for i in range(1,11):cur_x = x_seq[i*10].detach()axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');axs[i-1].set_axis_off();axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')

最后的演示

动画演示扩散过程和逆扩散过程import io
from PIL import Image
​
imgs = []
for i in range(100):plt.clf()q_i = q_x(dataset,torch.tensor([i]))plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);plt.axis('off');img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)imgs.append(img)
mg = Image.open(img_buf)reverse.append(img)
reverse = []
for i in range(100):plt.clf()cur_x = x_seq[i].detach()plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);plt.axis('off')img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)reverse.append(img)
​
imgs = imgs +reverse
imgs[0].save("diffusion.gif",format='GIF',append_images=imgs,save_all=True,duration=100,loop=0)

相关文章:

【扩散模型 李宏毅B站教学以及基础代码运用】

李宏毅教学视频: Link1 B站DDPM公式推导以及代码实现: Link2 这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。 文章目录 扩散模型概念:Diffusion Model工作原理:影像生成模型本质上的共同目标B站…...

SpringBoot隐藏文件

1.设置 2.输入file Types 3.点击忽略文件或者文件夹 4.成功...

常见数据库介绍对比之SQL关系型数据库

1.关系型数据库介绍 关系型数据库是一种基于关系模型的数据库,它使用表格来组织和存储数据。下面是一些常见的关系型数据库: 1.1. MySQL MySQL是一种开源的关系型数据库管理系统(RDBMS),广泛用于Web应用程序和企业级…...

OLED透明屏模块:引领未来显示技术的突破

OLED透明屏模块作为一项引领未来显示技术的突破,以其独特的特点和卓越的画质在市场上引起了广泛关注。 根据行业报告,预计到2025年,OLED透明屏模块将占据智能手机市场的20%份额,并在汽车导航系统市场中占据30%以上份额。 那么&am…...

Python_操作记录

1、Pandas读取数据文件(以文本文件作为示例),sep表示间隔,headerNone表示无标题行 df pd.read_table("data/youcans3.dat", sep"\t", headerNone) 2、线性规划问题求解 1)问题定义,…...

常用激活函数整理

最近一边应付工作,一边在补足人工智能的一些基础知识,这个方向虽然新兴,但已是卷帙浩繁,有时不知从何入手,幸亏有个适合基础薄弱的人士学习的网站,每天学习一点,积跬步以至千里吧。有像我一样学…...

uniapp 地图跳转到第三方导航软件 直接打包成apk

// 判断是否存在导航软件judgeHasExistNavignation() {let navAppParam [{pname: com.baidu.BaiduMap,action: baidumap://}, // 百度{pname: com.autonavi.minimap,action: iosamap://}, // 高德{pname: com.tencent.map,action: tencentmap://}, // 腾讯];return navAppPara…...

CentOS 8 通过YUM方式升级最新内核

CentOS 8 通过YUM方式升级最新内核 查看当前内核 uname -r 4.18.0-193.6.3.el8_2.x86_64导入 ELRepo 仓库的公钥: rpm --import https://www.elrepo.org/RPM-GPG-KEY-elrepo.org安装升级内核相关的yum源仓库(安装 ELRepo 仓库的 yum 源) yum install https://www…...

java 版本企业招标投标管理系统源码+功能描述+tbms+及时准确+全程电子化

功能描述 1、门户管理:所有用户可在门户页面查看所有的公告信息及相关的通知信息。主要板块包含:招标公告、非招标公告、系统通知、政策法规。 2、立项管理:企业用户可对需要采购的项目进行立项申请,并提交审批,查看所…...

Python爬虫数据存哪里|数据存储到文件的几种方式

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 爬虫请求解析后的数据,需要保存下来,才能进行下一步的处理,一般保存数据的方式有如下几种: 文件:txt、csv、excel、json等,保存数据量小。 关系型数据库…...

软件测试/测试开发丨Web自动化 测试用例流程设计

点此获取更多相关资料 本文为霍格沃兹测试开发学社学员学习笔记分享 原文链接:https://ceshiren.com/t/topic/27173 一、测试用例通用结构回顾 1.1、现有测试用例存在的问题 可维护性差可读性差稳定性差 1.2、用例结构设计 测试用例的编排测试用例的项目结构 1…...

git撤销修改命令

要撤销Git中尚未提交的所有修改,可以使用以下几种方法: 1、使用git checkout命令丢弃工作目录的修改,重置工作目录中所有文件的修改。 git checkout . 2、使用git reset命令重置暂存区和工作目录, 重置暂存区和工作目录,回到最后一次提交后的状态。 …...

EOCR-AR电机保护器自动复位的启用条件说明

为适用不同的现场使用需求,施耐德韩国公司推出了带有自动复位功能的模拟型电动机保护器-EOCR-AR。EOCR-AR电机保护器具有过电流、缺相、堵转保护功能,还可根据实际需要设置自动复位时间。 EOCR-AR自动复位的设置方法 如上图,R-TIME旋钮是自动…...

Apache nginx解析漏洞复现

文章目录 空字节漏洞安装环境漏洞复现 背锅解析漏洞安装环境漏洞复现 空字节漏洞 安装环境 将nginx解压后放到c盘根目录下: 运行startup.bat启动环境: 在HTML文件夹下有它的主页文件: 漏洞复现 nginx在遇到后缀名有php的文件时,…...

.NET之后,再无大创新

回想起来,2001年发布的.NET已经是距离最近的一次软件开发技术的整体创新了,后续的新技术就没有在各个端都这么成功的了。.NET是Windows平台下软件开发技术的巨大变革。在此之前,有VB、C(MFC)、JSP,在此之后…...

【大麦小米学量化】什么是量化交易?哪些人适合做量化交易?

系列文章目录 文章目录 系列文章目录学霸的梦想前言一、什么是量化交易?二、哪些人适合做量化交易?三、量化交易都需要掌握哪些技术和方法?总结 学霸的梦想 小米支棱着迷糊的眼睛,一脸懵逼的问大麦:“我说大麦哥哥&…...

计算机视觉的应用12-卷积神经网络中图像特征提取的可视化研究,让大家理解特征提取的全过程

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用12-卷积神经网络中图像特征提取的可视化研究,让大家理解特征提取的全过程。 要理解卷积神经网络中图像特征提取的全过程,我们可以将其比喻为人脑对视觉信息的处理过程。就像…...

el-table中点击跳转到详情页的两种方法

跳转的两种写法: 1.使用keep-alive使组件缓存,防止刷新时参数丢失 keep-alive 组件用于缓存和保持组件的状态,而不是路由参数。它可以在组件切换时保留组件的状态,从而避免重新渲染和加载数据。 keep-alive 主要用于提高页面性能和用户体验,而…...

RT-DETR个人整理向理解

一、前言 在开始介绍RT-DETR这个网络之前,我们首先需要先了解DETR这个系列的网络与我们常提及的anchor-base以及anchor-free存在着何种差异。 首先我们先简单讨论一下anchor-base以及anchor-free两者的差异与共性: 1、两者差异:顾名思义&…...

易点易动库存管理系统与ERP系统打通,帮助企业实现低值易耗品管理

现今,企业管理日趋复杂,无论是核心经营还是辅助环节,都需要依靠信息化手段来提升效率。而低值易耗品作为企业日常运营中的必需品,其管理也面临诸多挑战。传统做法效率低下,容易出错。如何通过信息化手段实现低值易耗品的高效管理,成为许多企业必顾及的一个课题。 易点易动作为…...

【笔试强训选择题】Day34.习题(错题)解析

作者简介:大家好,我是未央; 博客首页:未央.303 系列专栏:笔试强训选择题 每日一句:人的一生,可以有所作为的时机只有一次,那就是现在!!!&#xff…...

“现代”“修饰”卷积神经网络,何谓现代

一、“现代” vs “传统” 现代卷积神经网络(CNNs)与传统卷积神经网络之间存在一些关键区别。这些区别主要涉及网络的深度、结构、训练技巧和应用领域等方面。以下是现代CNNs与传统CNNs之间的一些区别: 深度: 传统CNNs&#xff1…...

XHTML基础知识了解

XHTML是一种严格符合XML规范的标记语言,它的基本语法和HTML类似,但是更加严谨和规范。XHTML的代码结构非常清晰,方便浏览器和搜索引擎解析。下面是一些XHTML的基础知识和代码示例: 声明文档类型(DTD) 在X…...

USB Server集中管控加密狗,浙江省电力设计院正在用

近日,软件加密狗的分散管理和易丢失性,给拥有大量加密狗的浙江省电力设计院带来了一系列的问题。好在浙江省电力设计院带及时使用了朝天椒USB Server方案,实现了加密狗的集中安全管控,避免了加密狗因为管理不善和遗失可能带来的巨…...

rust换源

在$HOME/.cargo/目录下建一个config文件。windows默认是C:\Users\user_name\.cargo。 config文件输入: [source.crates-io] registry "https://github.com/rust-lang/crates.io-index" # 使用 replace-with指明默认源更换为ustc源 replace-with ustc#…...

常见关系型数据库SQL增删改查语句

常见关系型数据库SQL增删改查语句: 创建表(Create Table): CREATE TABLE employees (id INT PRIMARY KEY,name VARCHAR(50),age INT,department VARCHAR(50) ); 插入数据(Insert Into): INSERT …...

OpenCV(二十七):图像距离变换

1.像素间距离 2.距离变换函数distanceTransform() void cv::distanceTransform ( InputArray src, OutputArray dst, int distanceType, int maskSize, int dstType CV_32F ) src:输入图像,数据类型为CV8U的单通道图像dst:输出图像,与输入图像…...

服务器就是一台电脑吗?服务器的功能和作用

服务器不仅仅是一台普通的电脑,它在功能和作用上有着显著的区别。下面是关于服务器的功能和作用的简要说明: 存储和共享数据:服务器可以用作数据存储和共享的中心。它们通常配备大容量的硬盘或固态硬盘,用于存储文件、数据库和其他…...

vue3实现塔罗牌翻牌

vue3实现塔罗牌翻牌 前言一、操作步骤1.布局2.操作3.样式 总结 前言 最近重刷诡秘之主,感觉里面的塔罗牌挺有意思,于是做了一个简单的塔罗牌翻牌动画(vue3vitets) 一、操作步骤 1.布局 首先我们定义一个整体的塔罗牌盒子&…...

分布式搜索引擎

1 DSL查询文档 elasticsearch的查询依然是基于JSON风格的DSL来实现的。 1.1.DSL查询分类 Elasticsearch提供了基于JSON的DSL(Domain Specific Language)来定义查询。常见的查询类型包括: 查询所有:查询出所有数据,一…...

免费云服务器官网/seo关键词优化排名公司

Django中路由的作用 ​ 其本质是URL与要为该URL调用的视图函数之间的映射表;你就是以这种方式告诉Django,对于客户端发来的某个URL调用哪一段逻辑代码对应执行 简单的路由配置 # Django1.0版本 from django.conf.urls import urlurlpatterns [url(正则表…...

做独立网站需要注意什么好/信息流广告代理商排名

3.7、环境配置(environments) 可以配置成适应多种环境,有助于将SQL映射应用于多种数据库之中。每个SqlSessionFactory实例只能选择一种环境(数据库),也就是做你想连接N个数据库,那么需要创建N个SqlSessionFactory实例…...

简历设计网官网入口/免费seo网站自动推广软件

奶制品。以低脂酸奶最佳,它富含钙质、多种维生素、蛋白质和钾元素。除此之外,酸奶中的益生菌更有助于保持体内菌群平衡。如果你不喜欢酸奶,脱脂牛奶和奶酪也是不错的选择。奶制品几乎包含了人体所需要的所有营养素,各种营养素之间…...

从零开始做一个网站需要多少钱/最新疫情新闻100字

计算轮廓点的最小凸包像素面积,最小外接圆的快速方法 计算轮廓点的最小凸包像素面积和求解最小外接圆的方法有很多,本文各举出一种比较简单且快速的方法,读者可根据实际情况定义返回值精度,本文采用的是整型精度,代码…...

唐山公司网站建设 中企动力/头条搜索

1、源代码安装nginx1)、所需环境:开发环境:Development toolsServer Platform DevelopmentAdditional Developmentpcre-develnginx-1.6.0.tar.gz2)、源码安装nginx:拆解源代码包到/usr/local/src/目录下创建系统账号和组nginx[roo…...

wordpress拍卖插件中文/人工智能培训机构排名前十

flask框架(四)1.蓝图的基本使用(掌握)作用:为了进行模块化开发特点:属于flask自带的,不需要安装扩展就能使用蓝图的使用流程1/创建蓝图对象(Blueprint)2/使用蓝图装饰视图函数3/将蓝图注册到app中(register_blueprint)user_blueBlueprint("user",__name__)其中的user…...