当前位置: 首页 > news >正文

扩散模型实战(四):从零构建扩散模型

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

本文以MNIST数据集为例,从零构建扩散模型,具体会涉及到如下知识点:

  • 退化过程(向数据中添加噪声)
  • 构建一个简单的UNet模型
  • 训练扩散模型
  • 采样过程分析

下面介绍具体的实现过程:

一、环境配置&python包的导入
     

最好有GPU环境,比如公司的GPU集群或者Google Colab,下面是代码实现:

# 安装diffusers库!pip install -q diffusers# 导入所需要的包import torchimport torchvisionfrom torch import nnfrom torch.nn import functional as Ffrom torch.utils.data import DataLoaderfrom diffusers import DDPMScheduler, UNet2DModelfrom matplotlib import pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f'Using device: {device}')
# 输出Using device: cuda

此时会输出运行环境是GPU还是CPU

载MNIST数据集

       MNIST数据集是一个小数据集,存储的是0-9手写数字字体,每张图片都28X28的灰度图片,每个像素的取值范围是[0,1],下面加载该数据集,并展示部分数据:

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)x, y = next(iter(train_dataloader))print('Input shape:', x.shape)print('Labels:', y)plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
# 输出Input shape: torch.Size([8, 1, 28, 28])Labels: tensor([7, 8, 4, 2, 3, 6, 0, 2])

扩散模型的退化过程

       所谓退化过程,其实就是对输入数据加入噪声的过程,由于MNIST数据集的像素范围在[0,1],那么我们加入噪声也需要保持在相同的范围,这样我们可以很容易的把输入数据与噪声进行混合,代码如下:

def corrupt(x, amount):  """Corrupt the input `x` by mixing it with noise according to `amount`"""  noise = torch.rand_like(x)  amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works  return x*(1-amount) + noise*amount

接下来,我们看一下逐步加噪的效果,代码如下:

# Plotting the input datafig, axs = plt.subplots(2, 1, figsize=(12, 5))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')# Adding noiseamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Plottinf the noised versionaxs[1].set_title('Corrupted data (-- amount increases -->)')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');

       从上图可以看出,从左到右加入的噪声逐步增多,当噪声量接近1时,数据看起来像纯粹的随机噪声。

构建一个简单的UNet模型

       UNet模型与自编码器有异曲同工之妙,UNet最初是用于完成医学图像中分割任务的,网络结构如下所示:

代码如下:

class BasicUNet(nn.Module):    """A minimal UNet implementation."""    def __init__(self, in_channels=1, out_channels=1):        super().__init__()        self.down_layers = torch.nn.ModuleList([             nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),            nn.Conv2d(32, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 64, kernel_size=5, padding=2),        ])        self.up_layers = torch.nn.ModuleList([            nn.Conv2d(64, 64, kernel_size=5, padding=2),            nn.Conv2d(64, 32, kernel_size=5, padding=2),            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),         ])        self.act = nn.SiLU() # The activation function        self.downscale = nn.MaxPool2d(2)        self.upscale = nn.Upsample(scale_factor=2)    def forward(self, x):        h = []        for i, l in enumerate(self.down_layers):            x = self.act(l(x)) # Through the layer and the activation function            if i < 2: # For all but the third (final) down layer:              h.append(x) # Storing output for skip connection              x = self.downscale(x) # Downscale ready for the next layer                      for i, l in enumerate(self.up_layers):            if i > 0: # For all except the first up layer              x = self.upscale(x) # Upscale              x += h.pop() # Fetching stored output (skip connection)            x = self.act(l(x)) # Through the layer and the activation function                    return x

我们来检验一下模型输入输出的shape变化是否符合预期,代码如下:

net = BasicUNet()x = torch.rand(8, 1, 28, 28)net(x).shape
# 输出torch.Size([8, 1, 28, 28])

再来看一下模型的参数量,代码如下:

sum([p.numel() for p in net.parameters()])
# 输出309057

至此,已经完成数据加载和UNet模型构建,当然UNet模型的结构可以有不同的设计。

、扩散模型训练

        扩散模型应该学习什么?其实有很多不同的目标,比如学习噪声,我们先以一个简单的例子开始,输入数据为带噪声的MNIST数据,扩散模型应该输出对应的最佳数字预测,因此学习的目标是预测值与真实值的MSE,训练代码如下:

# Dataloader (you can mess with batch size)batch_size = 128train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# How many runs through the data should we do?n_epochs = 3# Create the networknet = BasicUNet()net.to(device)# Our loss finctionloss_fn = nn.MSELoss()# The optimizeropt = torch.optim.Adam(net.parameters(), lr=1e-3) # Keeping a record of the losses for later viewinglosses = []# The training loopfor epoch in range(n_epochs):    for x, y in train_dataloader:        # Get some data and prepare the corrupted version        x = x.to(device) # Data on the GPU        noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts        noisy_x = corrupt(x, noise_amount) # Create our noisy x        # Get the model prediction        pred = net(noisy_x)        # Calculate the loss        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?        # Backprop and update the params:        opt.zero_grad()        loss.backward()        opt.step()        # Store the loss for later        losses.append(loss.item())    # Print our the average of the loss values for this epoch:    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')# View the loss curveplt.plot(losses)plt.ylim(0, 0.1);
# 输出Finished epoch 0. Average loss for this epoch: 0.024689Finished epoch 1. Average loss for this epoch: 0.019226Finished epoch 2. Average loss for this epoch: 0.017939

训练过程的loss曲线如下图所示:

六、扩散模型效果评估

我们选取一部分数据来评估一下模型的预测效果,代码如下:

#@markdown Visualizing model predictions on noisy inputs:# Fetch some datax, y = next(iter(train_dataloader))x = x[:8] # Only using the first 8 for easy plotting# Corrupt with a range of amountsamount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruptionnoised_x = corrupt(x, amount)# Get the model predictionswith torch.no_grad():  preds = net(noised_x.to(device)).detach().cpu()# Plotfig, axs = plt.subplots(3, 1, figsize=(12, 7))axs[0].set_title('Input data')axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')axs[1].set_title('Corrupted data')axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')axs[2].set_title('Network Predictions')axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');

从上图可以看出,对于噪声量较低的输入,模型的预测效果是很不错的,当amount=1时,模型的输出接近整个数据集的均值,这正是扩散模型的工作原理。

Note:我们的训练并不太充分,读者可以尝试不同的超参数来优化模型。

相关文章:

扩散模型实战(四):从零构建扩散模型

推荐阅读列表&#xff1a; 扩散模型实战&#xff08;一&#xff09;&#xff1a;基本原理介绍 扩散模型实战&#xff08;二&#xff09;&#xff1a;扩散模型的发展 扩散模型实战&#xff08;三&#xff09;&#xff1a;扩散模型的应用 本文以MNIST数据集为例&#xff0c;从…...

YOLOv5、YOLOv8改进:S2注意力机制

目录 1.简介 2.YOLOv5改进 2.1增加以下S2-MLPv2.yaml文件 2.2common.py配置 2.3yolo.py配置 1.简介 S2-MLPv2注意力机制 最近&#xff0c;出现了基于 MLP 的视觉主干。与 CNN 和视觉Transformer相比&#xff0c;基于 MLP 的视觉架构具有较少的归纳偏差&#xff0c;在图像识…...

LeetCode 542. 01 Matrix【多源BFS】中等

本文属于「征服LeetCode」系列文章之一&#xff0c;这一系列正式开始于2021/08/12。由于LeetCode上部分题目有锁&#xff0c;本系列将至少持续到刷完所有无锁题之日为止&#xff1b;由于LeetCode还在不断地创建新题&#xff0c;本系列的终止日期可能是永远。在这一系列刷题文章…...

使用open cv进行角度测量

使用open cv进行角度测量 用了一点初中数学的知识&#xff0c;准确度&#xff0c;跟鼠标点的准不准有关系&#xff0c;话不多说直接上代码 import cv2 import mathpath "test.jpg" img cv2.imread(path) pointsList []def mousePoint(event, x, y, flags, param…...

java 线程池实现多线程处理list数据

newFixedThreadPool线程池实现多线程 List<PackageAgreementEntity> entityList new CopyOnWriteArrayList<>();//多线程 10个线程//int threadNum 10;int listSize 300;List<List<PackageAgreementDto>> splitData Lists.partition(packageAgre…...

Centos安装Docker

Centos安装 Docker 从 2017 年 3 月开始 docker 在原来的基础上分为两个分支版本: Docker CE 和 Docker EE。 Docker CE 即社区免费版&#xff0c;Docker EE 即企业版&#xff0c;强调安全&#xff0c;但需付费使用。 本文介绍 Docker CE 的安装使用。 移除旧的版本&#x…...

Unity启动项目无反应的解决

文章首发见博客&#xff1a;https://mwhls.top/4803.html。 无图/格式错误/后续更新请见首发页。 更多更新请到mwhls.top查看 欢迎留言提问或批评建议&#xff0c;私信不回。 摘要&#xff1a;通过退还并重新载入许可证以解决Unity项目启动无反应问题。 场景 Unity Hub启动项目…...

2.3 opensbi: riscv: opensbi源码解析

文章目录 3. sbi_init()函数4. init_coldboot()函数4.1 sbi_scratch_init()函数4.2 sbi_domain_init()函数4.3 sbi_scratch_alloc_offset()函数4.4 sbi_hsm_init()函数4.5 sbi_platform_early_init()函数3. sbi_init()函数 函数位置:lib/sbi/sbi_init.c函数参数:scratch为每个…...

点破ResNet残差网络的精髓

卷积神经网络在实际训练过程中&#xff0c;不可避免会遇到一个问题&#xff1a;随着网络层数的增加&#xff0c;模型会发生退化。    换句话说&#xff0c;并不是网络层数越多越好&#xff0c;为什么会这样&#xff1f; 不是说网络越深&#xff0c;提取的特征越多&#xff…...

Ubuntu服务器service版本初始化

下载 下载路径 官网&#xff1a;https://cn.ubuntu.com/ 下载路径&#xff1a;https://cn.ubuntu.com/download 服务器&#xff1a;https://cn.ubuntu.com/download/server/step1 点击下载&#xff08;22.04.3&#xff09;&#xff1a;https://cn.ubuntu.com/download/server…...

re学习(33)攻防世界-secret-galaxy-300(脑洞题)

下载压缩包&#xff1a; 下载链接&#xff1a;https://adworld.xctf.org.cn/challenges/list 参考文章&#xff1a;攻防世界逆向高手题之secret-galaxy-300_沐一 林的博客-CSDN博客 发现这只是三个同一类型文件的三个不同版本而已&#xff0c;一个windows32位exe&#xff0…...

Mybatis Plus中使用LambdaQueryWrapper进行分页以及模糊查询对比传统XML方式进行分页

传统的XML分页以及模糊查询操作 传统的XML方式只能使用limit以及offset进行分页&#xff0c;通过判断name和bindState是否为空&#xff0c;不为空则拼接条件。 List<SanitationCompanyStaff> getSanitationStaffInfo(Param("name") String name,Param("bi…...

vue中push和resolve的区别

import { useRouter } from vue-router;const routeuseRouter()route.push({path:/test,query:{name:1}})import { useRouter } from vue-router;const routeuseRouter()const urlroute.resolve({path:/test,query:{name:1}})window.open(url.href)比较上述代码会发现,resolve能…...

详解RFC 3550文档-1

1. 介绍 rfc 3550描述了实时传输协议RTP。RTP提供端到端的网络传输功能,适用于通过组播或单播网络服务传输实时数据(如音频、视频或仿真数据)的应用。 TP本身不提供任何机制来确保及时交付或提供其他服务质量保证,而是依赖于较低层的服务来完成这些工作。它不保证传输或防止…...

Go 与 Rust

目录 1. Go 与 Rust 1. Go 与 Rust 一位挺 Rust 的网友说道: “我也为这个选择烦恼了很久。最终 Rust 胜出了。首先, 我感觉 Rust 更接近于以前 Pascal 时代的东西, 你可以控制一切; 其次, 如果 wasm 和相关技术大爆发, Rust 将是一个更安全的选择; 然后, 我们已经有了 Python…...

Android Studio实现读取本地相册文件并展示

目录 原文链接效果 代码activity_main.xmlMainActivity 原文链接 效果 代码 activity_main.xml 需要有一个按钮和image来展示图片 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk…...

python的全局解释锁(GIL)

一、介绍 全局解释锁&#xff08;Global Interpreter Lock&#xff0c;GIL&#xff09;是在某些编程语言的解释器中使用的一种机制。在Python中&#xff0c;GIL是为了保证解释器线程安全而引入的。 GIL的作用是在解释器的执行过程中&#xff0c;确保同一时间只有一个线程可以…...

小程序swiper一个轮播显示一个半内容且实现无缝滚动

效果图&#xff1a; wxml&#xff08;无缝滚动&#xff1a;circular"true"&#xff09;&#xff1a; <!--components/tool_version/tool_version.wxml--> <view class"tool-version"><swiper class"tool-version-swiper" circul…...

【自然语言处理】关系抽取 —— SimpleRE 讲解

SimpleRE 论文信息 标题:An Embarrassingly Simple Model for Dialogue Relation Extraction 作者:Fuzhao Xue 期刊:ICASSP 2022 发布时间与更新时间:2020.12.27 2022.01.25 主题:自然语言处理、关系抽取、对话场景、BERT arXiv:[2012.13873] An Embarrassingly Simple M…...

【O2O领域】Axure外卖订餐骑手端APP原型图,外卖众包配送原型设计图

作品概况 页面数量&#xff1a;共 110 页 兼容软件&#xff1a;Axure RP 9/10&#xff0c;不支持低版本 应用领域&#xff1a;外卖配送、生鲜配送 作品申明&#xff1a;页面内容仅用于功能演示&#xff0c;无实际功能 作品特色 本品为外卖订餐骑手端APP原型设计图&#x…...

DataGridView keydown事件无法在C#中工作

原因&#xff1a;单元格内编辑文本时,DataGridView keydown事件不起作用。每当单元格处于编辑模式时,其托管控件就会接收KeyDown事件而不是DataGridView包含它的父级.这就是为什么当单元格未处于编辑模式时(即使它被选中),键盘快捷键正常工作,因为DataGridView控件本身会收到Ke…...

【ElasticSearch】一键安装ElasticSearch与Kibana以及解决遇到的问题

目录 一、安装ES 二、安装Kibana 三、遇到的问题 一、安装ES 按顺序复制即可 docker network create es-net # 创建网络 docker pull images:7.12.1 # 拉取镜像 mkdir -p /root/es/data # 创建数据卷 mkdir -p /root/es/plugins # 创建数据卷 chmod 777 /root/es/** # 设置权…...

电商数据采集和数据分析

不管是做渠道价格的治理&#xff0c;还是做窜货、假货的打击&#xff0c;都需要品牌对线上数据尽数掌握&#xff0c;准确的数据是驱动服务的关键&#xff0c;所以做好电商数据的采集和分析非常重要。 当线上链接较多&#xff0c;品牌又需要监测线上数据时&#xff0c;单靠人工肯…...

react 11之 router6路由 (两种路由模式、两种路由跳转、两种传参与接收参数、嵌套路由,layout组件、路由懒加载)

目录 react路由1&#xff1a;安装和两种模式react路由2&#xff1a;两种路由跳转 &#xff08; 命令式与编程式&#xff09;2-1 路由跳转-命令式2-2 路由跳转-编程式 - 函数组件2-2-1 app.jsx2-2-2 page / Home.jsx2-2-3 page / About.jsx2-2-4 效果 react路由3&#xff1a;函数…...

Golang 基础语法问答

使用值为 nil 的 slice、map 会发生什么&#xff1f; 允许对值为 nil 的 slice 添加元素&#xff0c;但是对值为 nil 的 map 添加元素时会造成运行时 panic。 // map错误示例 func main() {var m map[string]intm["one"] 1 // error: panic: assignment to entry …...

冠达管理:哪里查中报预增?

中报季行将到来&#xff0c;投资者开端重视公司的成绩体现。中报预增是投资者最关心的论题之一&#xff0c;因为这意味着公司未来成绩的增加潜力。但是&#xff0c;怎么查找中报预增的信息呢&#xff1f;本文将从多个视点分析这个问题。 1.证券交易所网站 证券交易所网站是投资…...

docker安装Oracle11gR2

文章目录 目录 文章目录 前言 一、前期准备 二、具体配置 2.1 配置oracle容器 2.2 配置navicat连接 总结 前言 使用docker模拟oracle环境 一、前期准备 安装好docker #拉取镜像 docker pull registry.cn-hangzhou.aliyuncs.com/helowin/oracle_11g #启动 docker run -…...

unity 之 Input.GetMouseButtonDown 的使用

文章目录 Input.GetMouseButtonDown Input.GetMouseButtonDown 当涉及到处理鼠标输入的时候&#xff0c;Input.GetMouseButtonDown 是一个常用的函数。它可以用来检测鼠标按键是否在特定帧被按下。下面我会详细介绍这个函数&#xff0c;并举两个例子说明如何使用它。 函数签名…...

链游再进化 Web3版CSGO来袭

过去几年&#xff0c;游戏开发者们一直希望借Web3这个价值流通网络&#xff0c;改造传统游戏的经济系统&#xff0c;将虚拟资产的掌管权交给用户&#xff0c;让资产自由地在市场流通。 Web3游戏发展史上&#xff0c;涌现过CryptoKitties、Axie Infinity两大爆款&#xff0c;但…...

WordPress用于您的企业网站的优点和缺点

如今&#xff0c;WordPress 被广泛认为是一个可靠、可扩展且安全的平台&#xff0c;能够为商业网站提供支持。然而&#xff0c;许多人质疑 WordPress 是否是适合企业的平台。 这就是我们创建本指南的原因。通过探索 WordPress 的优点和缺点&#xff0c;您可以确定世界上最受欢…...

抖音小程序开发公司/深圳seo优化外包公司

Spring2.5 的事物配置&#xff1a;<!-- 配置事务通知 --><tx:advice id"txAdvice" transactionManagertxManager><!-- REQUIRED:表示当前方法必须运行在一个事物中&#xff0c;没有事物则创建一个。SUPPORTS&#xff1a;表示当前方法不需要事物-->…...

mac wordpress数据库文件路径/今日大事件新闻

--- 说明闪回数据库 --- 使用闪回表将表内容还原到过去的特定时间点 --- 从删除表中进行恢复 --- 使用闪回查询查看截止到任一时间点的数据库内容 --- 使用闪回版本查询查看某一行在一段时间内的各个版本 --- 使用闪回事务查询查看事务处理历史记录或行 优点&#xff1a; …...

企业网站管理系统的设计与实现/淘宝关键词查询工具

七、radio 单选框radio-group 属性名类型默认值说明bindchangeEventHandle radio-group中的选中项发生变化时触发change事件&#xff0c;event.detai {value : 选中项radio的valueradio 属性名类型默认值说明valueString radio标识。当该radio选中时&#xff0c;radio-group的…...

一个在线做笔记的网站/百度如何做广告

一、操作系统提代一个开发接口给硬件开发商&#xff0c;让他们可以根据这个接口设计可以驱动他们硬件的驱动程序。二、应用程序是参考操作系统提供开发接口所开发出来的软件三、核心功能&#xff1a;系统呼叫接口、程序管理、内存管理、文件系统管理、装置驱动四、CPU每次能够处…...

个人网站建设程序设计/google ads

Shell文本处理三剑客之一awk(2) 表达式与其他编程语言一样&#xff0c;awk表达式用于存储&#xff0c;操作和获取数据 一个awk表达式可由数值&#xff0c;字符常量&#xff0c;变量&#xff0c;操作符函数和正则表达式自由组合而成 变量是一个值的标识符&#xff0c;定义awk变量…...

桂林生活网疫情最新消息/搜索引擎关键词优化

blog地址&#xff1a;https://blog.friddle.me/post/frida-js-de-retrofit-si-lu-he-chang-shi/开始博客又搞来搞去。本来准备在知乎上了。不过可以在这里写。然后知乎上再拷贝一份 Retrofit是一个很牛逼的框架/Frida也是。我作为新手。通过hack某个App一周多。也算正式入门了目…...