当前位置: 首页 > news >正文

深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

导入python包

import mathimport torch
import torch.nn as nn
import torch.nn.functional as F

 silu激活函数

class SiLU(nn.Module):  # SiLU激活函数@staticmethoddef forward(x):return x * torch.sigmoid(x)

归一化设置

def get_norm(norm, num_channels, num_groups):if norm == "in":return nn.InstanceNorm2d(num_channels, affine=True)elif norm == "bn":return nn.BatchNorm2d(num_channels)elif norm == "gn":return nn.GroupNorm(num_groups, num_channels)elif norm is None:return nn.Identity()else:raise ValueError("unknown normalization type")

 计算时间步长的位置嵌入,一半为sin,一半为cos

class PositionalEmbedding(nn.Module):def __init__(self, dim, scale=1.0):super().__init__()assert dim % 2 == 0self.dim = dimself.scale = scaledef forward(self, x):device      = x.devicehalf_dim    = self.dim // 2emb = math.log(10000) / half_dimemb = torch.exp(torch.arange(half_dim, device=device) * -emb)# x * self.scale和emb外积emb = torch.outer(x * self.scale, emb)emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return emb

 上下采样层设置

class Downsample(nn.Module):def __init__(self, in_channels):super().__init__()self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)def forward(self, x, time_emb, y):if x.shape[2] % 2 == 1:raise ValueError("downsampling tensor height should be even")if x.shape[3] % 2 == 1:raise ValueError("downsampling tensor width should be even")return self.downsample(x)class Upsample(nn.Module):def __init__(self, in_channels):super().__init__()self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"),nn.Conv2d(in_channels, in_channels, 3, padding=1),)def forward(self, x, time_emb, y):return self.upsample(x)

 使用Self-Attention注意力机制,做一个全局的Self-Attention

class AttentionBlock(nn.Module):def __init__(self, in_channels, norm="gn", num_groups=32):super().__init__()self.in_channels = in_channelsself.norm = get_norm(norm, in_channels, num_groups)self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)self.to_out = nn.Conv2d(in_channels, in_channels, 1)def forward(self, x):b, c, h, w  = x.shapeq, k, v     = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)q = q.permute(0, 2, 3, 1).view(b, h * w, c)k = k.view(b, c, h * w)v = v.permute(0, 2, 3, 1).view(b, h * w, c)dot_products = torch.bmm(q, k) * (c ** (-0.5))assert dot_products.shape == (b, h * w, h * w)attention   = torch.softmax(dot_products, dim=-1)out         = torch.bmm(attention, v)assert out.shape == (b, h * w, c)out         = out.view(b, h, w, c).permute(0, 3, 1, 2)return self.to_out(out) + x

 用于特征提取的残差结构

class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,norm="gn", num_groups=32, use_attention=False,):super().__init__()self.activation = activationself.norm_1 = get_norm(norm, in_channels, num_groups)self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)self.norm_2 = get_norm(norm, out_channels, num_groups)self.conv_2 = nn.Sequential(nn.Dropout(p=dropout), nn.Conv2d(out_channels, out_channels, 3, padding=1),)self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else Noneself.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else Noneself.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)def forward(self, x, time_emb=None, y=None):out = self.activation(self.norm_1(x))# 第一个卷积out = self.conv_1(out)# 对时间time_emb做一个全连接,施加在通道上if self.time_bias is not None:if time_emb is None:raise ValueError("time conditioning was specified but time_emb is not passed")out += self.time_bias(self.activation(time_emb))[:, :, None, None]# 对种类y_emb做一个全连接,施加在通道上if self.class_bias is not None:if y is None:raise ValueError("class conditioning was specified but y is not passed")out += self.class_bias(y)[:, :, None, None]out = self.activation(self.norm_2(out))# 第二个卷积+残差边out = self.conv_2(out) + self.residual_connection(x)# 最后做个Attentionout = self.attention(out)return out

 U-Net模型设计

class UNet(nn.Module):def __init__(self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,):super().__init__()# 使用到的激活函数,一般为SILUself.activation = activation# 是否对输入进行paddingself.initial_pad = initial_pad# 需要去区分的类别数self.num_classes = num_classes# 对时间轴输入的全连接层self.time_mlp = nn.Sequential(PositionalEmbedding(base_channels, time_emb_scale),nn.Linear(base_channels, time_emb_dim),nn.SiLU(),nn.Linear(time_emb_dim, time_emb_dim),) if time_emb_dim is not None else None# 对输入图片的第一个卷积self.init_conv  = nn.Conv2d(img_channels, base_channels, 3, padding=1)# self.downs用于存储下采样用到的层,首先利用ResidualBlock提取特征# 然后利用Downsample降低特征图的高宽self.downs      = nn.ModuleList()self.ups        = nn.ModuleList()# channels指的是每一个模块处理后的通道数# now_channels是一个中间变量,代表中间的通道数channels        = [base_channels]now_channels    = base_channelsfor i, mult in enumerate(channel_mults):out_channels = base_channels * multfor _ in range(num_res_blocks):self.downs.append(ResidualBlock(now_channels, out_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,))now_channels = out_channelschannels.append(now_channels)if i != len(channel_mults) - 1:self.downs.append(Downsample(now_channels))channels.append(now_channels)# 可以看作是特征整合,中间的一个特征提取模块self.mid = nn.ModuleList([ResidualBlock(now_channels, now_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,norm=norm, num_groups=num_groups, use_attention=True,),ResidualBlock(now_channels, now_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=False,),])# 进行上采样,进行特征融合for i, mult in reversed(list(enumerate(channel_mults))):out_channels = base_channels * multfor _ in range(num_res_blocks + 1):self.ups.append(ResidualBlock(channels.pop() + now_channels, out_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,))now_channels = out_channelsif i != 0:self.ups.append(Upsample(now_channels))assert len(channels) == 0self.out_norm = get_norm(norm, base_channels, num_groups)self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)def forward(self, x, time=None, y=None):# 是否对输入进行paddingip = self.initial_padif ip != 0:x = F.pad(x, (ip,) * 4)# 对时间轴输入的全连接层if self.time_mlp is not None:if time is None:raise ValueError("time conditioning was specified but tim is not passed")time_emb = self.time_mlp(time)else:time_emb = Noneif self.num_classes is not None and y is None:raise ValueError("class conditioning was specified but y is not passed")# 对输入图片的第一个卷积x = self.init_conv(x)# skips用于存放下采样的中间层skips = [x]for layer in self.downs:x = layer(x, time_emb, y)skips.append(x)# 特征整合与提取for layer in self.mid:x = layer(x, time_emb, y)# 上采样并进行特征融合for layer in self.ups:if isinstance(layer, ResidualBlock):x = torch.cat([x, skips.pop()], dim=1)x = layer(x, time_emb, y)# 上采样并进行特征融合x = self.activation(self.out_norm(x))x = self.out_conv(x)if self.initial_pad != 0:return x[:, :, ip:-ip, ip:-ip]else:return x

参考链接:GitCode - 开发者的代码家园icon-default.png?t=N7T8https://gitcode.com/bubbliiiing/ddpm-pytorch/tree/master?utm_source=csdn_github_accelerator&isLogin=1

相关文章:

深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

导入python包 import mathimport torch import torch.nn as nn import torch.nn.functional as F silu激活函数 class SiLU(nn.Module): # SiLU激活函数staticmethoddef forward(x):return x * torch.sigmoid(x) 归一化设置 def get_norm(norm, num_channels, num_groups)…...

C#使用RabbitMQ-2_详解工作队列模式

简介 🍀RabbitMQ中的工作队列模式是指将任务分配给多个消费者并行处理。在工作队列模式中,生产者将任务发送到RabbitMQ交换器,然后交换器将任务路由到一个或多个队列。消费者从队列中获取任务并进行处理。处理完成后,消费者可以向…...

Day37 56合并区间 738单调递增的数字 968监控二叉树

56 合并区间 给出一个区间的集合&#xff0c;请合并所有重叠的区间。 示例 1: 输入: intervals [[1,3],[2,6],[8,10],[15,18]]输出: [[1,6],[8,10],[15,18]]解释: 区间 [1,3] 和 [2,6] 重叠, 将它们合并为 [1,6]. class Solution { public:vector<vector<int>>…...

【Android】在WSA安卓子系统中进行新实验性功能试用与抓包(2311.4.5.0)

前言 在根据几篇22和23的WSA抓包文章进行尝试时遇到了问题&#xff0c;同时发现新版Wsa的一些实验性功能能优化抓包配置时的一些步骤&#xff0c;因而写下此篇以作记录。 Wsa版本&#xff1a;2311.40000.5.0 本文出现的项目&#xff1a; MagiskOnWSALocal MagiskTrustUserCer…...

【服务器】服务器的管理口和网口

服务器通常会有两种不同类型的网络接口&#xff0c;即管理口&#xff08;Management Port&#xff09;和网口&#xff08;Ethernet Port&#xff09;&#xff0c;它们的作用和用途不同。 一、管理口 管理口通常是用于服务器管理的网络接口&#xff0c;也被称为外带网卡或带外接…...

一个小例子,演示函数指针

结构体里经常看到函数指针的写法&#xff0c;函数指针其实就是函数的名字。但是结构体里你要是直接把一个函数摆上去&#xff0c;那就变成成员变量&#xff0c;就会发生混乱 1. 函数指针 #include <unistd.h> #include <stdio.h>struct Kiwia{void (*func)(int )…...

python12-Python的字符串之使用input获取用户输入

input()函数用于向用户生成一条提示,然后获取用户输入的内容。由于input0函数总会将用户输入的内容放入字符串中,因此用户可以输入任何内容,input()函数总是返回一个字符串。例如如下程序。 # !/usr/bin/env python# -*- coding: utf-8 -*-# @Time : 2024/01# @Author : Lao…...

【代码随想录-数组】移除元素

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学习,不断总结,共同进步,活到老学到老导航 檀越剑指大厂系列:全面总结 jav…...

springboot事务管理

/*spring事务管理注解:Transactional位置:业务(service)层的方法上、类上、接口上作用:将当前方法交给spring进行事务管理&#xff0c;方法执行前&#xff0c;开启事务:成功执行完毕&#xff0c;提交事务:出现常&#xff0c;回滚事务需要在配置文件是加上开启spring事务yml文件…...

数据结构——链式二叉树(2)

目录 &#x1f341;一、二叉树的销毁 &#x1f341;二、在二叉树中查找某个数&#xff0c;并返回该结点 &#x1f341;三、LeetCode——检查两棵二叉树是否相等 &#x1f315;&#xff08;一&#xff09;、题目链接&#xff1a;100. 相同的树 - 力扣&#xff08;LeetCode&a…...

spring-boot-starter-validation常用注解

文章目录 一、使用二、常用注解三、Valid or Validated &#xff1f;四、分组校验1. 分组校验的基本概念2. 定义验证组3. 应用分组到模型4. 在控制器中使用分组5. 总结 一、使用 要使用这些注解&#xff0c;首先确保在你的 Spring Boot 应用的 pom.xml 文件中添加了 spring-bo…...

AF700 NHS 酯,AF 700 Succinimidyl Ester,一种明亮且具有光稳定性的近红外染料

AF700 NHS 酯&#xff0c;AF 700 Succinimidyl Ester&#xff0c;一种明亮且具有光稳定性的近红外染料&#xff0c;AF700-NHS-酯&#xff0c;具有水溶性和 pH 值不敏感性 您好&#xff0c;欢迎来到新研之家 文章关键词&#xff1a;AF700 NHS 酯&#xff0c;AF 700 Succinimid…...

C#常见内存泄漏

背景 在开发中由于对语言特性不了解或经验不足或疏忽&#xff0c;往往会造成一些低级bug。而内存泄漏就是最常见的一个&#xff0c;这个问题在测试过程中&#xff0c;因为操作频次低&#xff0c;而不能完全被暴露出来&#xff1b;而在正式使用时&#xff0c;由于使用次数增加&…...

Xmind安装到指定目录

Xmind安装到指定目录 默认情况下安装包自动引导安装在C盘&#xff08;注册表默认位置&#xff09; T1:修改注册表&#xff0c;比较麻烦 T2:安装时命令行指定安装位置&#xff0c;快捷省事 1&#xff09;下载安装包&#xff08;exe可执行文件&#xff09; 2&#xff09;安装…...

[GXYCTF2019]BabyUpload1

尝试各种文件&#xff0c;黑名单过滤后缀ph&#xff0c;content-type限制image/jpeg 内容过滤<?&#xff0c;木马改用<script languagephp>eval($_POST[cmdjs]);</script> 上传.htaccess将上传的文件当作php解析 蚁剑连接得到flag...

SpringBoot之分页查询的使用

背景 在业务中我们在前端总是需要展示数据&#xff0c;将后端得到的数据进行分页处理&#xff0c;通过pagehelper实现动态的分页查询&#xff0c;将查询页数和分页数通过前端发送到后端&#xff0c;后端使用pagehelper&#xff0c;底层是封装threadlocal得到页数和分页数并动态…...

【shell-10】shell实现的各种kafka脚本

kafka-shell工具 背景日志 log一.启动kafka->(start-kafka)二.停止kafka->(stop-kafka)三.创建topic->(create-topic)四.删除topic->(delete-topic)五.获取topic列表->(list-topic)六. 将文件数据 录入到kafka->(file-to-kafka)七.将kafka数据 下载到文件-&g…...

【模型压缩】模型剪枝详解

参考链接:https://zhuanlan.zhihu.com/p/635454943 https 文章目录 1. 前言1.1 为什么要进行模型剪枝1.2 为什么可以进行模型剪枝2. 剪枝方式的几种分类2.1 结构化剪枝 和 非结构化剪枝2.1.1 结构化剪枝2.1.2 非结构化剪枝2.2 静态剪枝与动态剪枝2.2.1 静态剪枝2.2.2 动态剪枝…...

Log4j2-01-log4j2 hello world 入门使用

拓展阅读 Log4j2 系统学习 Logback 系统学习 Slf4j Slf4j-02-slf4j 与 logback 整合 SLF4j MDC-日志添加唯一标识 分布式链路追踪-05-mdc 等信息如何跨线程? Log4j2 与 logback 的实现方式 日志开源组件&#xff08;一&#xff09;java 注解结合 spring aop 实现自动输…...

Mysql-日志介绍 日志配置

环境部署 docker run -d -p 3306:3306 --privilegedtrue -v $(pwd)/logs:/var/lib/logs -v $(pwd)/conf:/etc/mysql/conf.d -v $(pwd)/data:/var/lib/mysql -e MYSQL_ROOT_PASSWORD654321 --name mysql mysql:5.7运行指令的目录下新建好这些文件&#xff1a; 日志类型 日…...

计算机网络的体系结构的各层在整个过程中起到什么作用?

ps&#xff1a;本文章的图片内容来源都是来自于湖科大教书匠的视频&#xff0c;声明&#xff1a;仅供自己复习&#xff0c;里面加上了自己的理解 这里附上视频链接地址&#xff1a;1.6 计算机网络体系结构&#xff08;4&#xff09;—专用术语_哔哩哔哩_bilibili 目录 &#x…...

如何在业务代码中优雅的使用策略模式?

策略模式介绍 假设你正在开发一个电商平台&#xff0c;其中涉及到商品的折扣策略。优惠策略有很多种可能&#xff0c;如领取优惠券抵扣、返现促销、拼团优惠等。最初的实现可能会在购物车类中嵌入各种折扣逻辑&#xff0c;导致代码的可维护性和扩展性下降。 下面代码在业务开…...

“docker-credential-desktop.exe“: executable file not found in $PATH 错误解决

"docker-credential-desktop.exe": executable file not found in $PATH 错误解决 1. 错误信息和解决方法 1. 错误信息和解决方法 错误信息&#xff0c; error getting credentials - err: exec: "docker-credential-desktop.exe": executable file not …...

openssl3.2/test/certs - 055 - all DNS-like CNs allowed by CA1, no DNS SANs

文章目录 openssl3.2/test/certs - 055 - all DNS-like CNs allowed by CA1, no DNS SANs概述笔记END openssl3.2/test/certs - 055 - all DNS-like CNs allowed by CA1, no DNS SANs 概述 openssl3.2 - 官方demo学习 - test - certs 笔记 /*! * \file D:\my_dev\my_local_…...

长虹智能电视6000iD、6080iD、3000iD、U2系列等 ZLM61HiPJ机芯升级刷机方法,附刷机数据

机芯&#xff1a;ZLM61HiPJ 适用机型&#xff1a;UD39B6000iD、UD42B6000iD、UD50B6000iD、UD55B6000iD、UD42C6000iD、UD42C6080iD、UD49C6000iD、UD49C6080iD、UD55C6000iD、UD55C6080iD、UD50C6000iD、UD58C3000iD、42U2 LE42C19S-UD、LE50C29SD-UD、 UD49C6000iD(LJM2W)、…...

六、VTK创建平面vtkPlaneSource

vtkPlaneSource创建位于平面中的四边形数组 先看看效果图: vtkPlaneSource 创建一个 m x n 个四边形数组,这些四边形在平面中排列为规则平铺。通过指定一个原点来定义平面,然后指定另外两个点,这两个点与原点一起定义平面的两个轴。这些轴不必是正交的 - 因此您可以创建平行…...

LiveGBS流媒体平台GB/T28181常见问题-如何配置使用自己已有的redis服务替换redis版本升级redis版本

LiveGBS如何配置使用自己已有的redis服务替换redis版本升级redis版本 1、Redis服务2、如何切换REDIS?2.1、停止启动REDIS2.2、配置信令服务2.3、配置流媒体服务2.4、启动 3、搭建GB28181视频直播平台 1、Redis服务 在LivGBS中Redis作为数据交换、数据订阅、数据发布的高速缓存…...

stm32产品架构

文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据 总结 前言 起因是我在看野火的ucosiii&#xff0c;然后他是基于i.mx芯片。然后我就很疑惑i.mx是什么芯片&#xff0c;看了下好像是ARM-M7(或者叫ARMCM7)架构的芯片。然后我又疑惑ARM-M7又是什么架…...

数据结构——双链表

双链表中节点类型的描述&#xff1a; 双链表的初始化&#xff08;带头结点&#xff09; 、 双链表的插入操作 后插操作 InsertNextDNode(p, s): 在p结点后插入s结点 按位序插入操作&#xff1a; 思路&#xff1a;从头结点开始&#xff0c;找到某个位序的前驱结点&#xff…...

Git 对文件名大小写不敏感的问题解决方案

目录 一、Git 对文件名大小写不敏感1.1 问题描述1.2 原因分析1.3 解决方案方式一&#xff1a;使用git命令进行修改方式二&#xff1a;关闭git 忽略大小写配置 &#xff08;可以当前项目设置&#xff0c;也可以全局设置 --global&#xff09; 二、新的问题&#xff08;重复的目录…...

wordpress edit_post_link/营销网站建设培训学校

在一个View中&#xff0c;系统提供了scrollTo、scrollBy两种方式来改变一个View的位置。这两个方法的区别非常好理解&#xff0c;与英文中To与By的区别类似&#xff0c;scrollTo(x, y)标识移动到一个具体的坐标点(x, y)&#xff0c;而scrollBy(dx, dy)表示移动的增量为dx、dy。…...

网站没后台怎么修改类容/如何在百度上添加自己的店铺

一、什么是库?在windows平台和linux平台下都大量存在着库。一般是软件作者为了发布方便、替换方便或二次开发目的&#xff0c;而发布的一组可以单独与应用程序进行compile time或runtime链接的二进制可重定位目标码文件。本质上来说库是一种可执行代码的二进制形式&#xff0c…...

盘锦做网站选哪家好/公司软文

attr()方法是用于设置标签的属性&#xff0c;比如src&#xff0c;href&#xff0c;title&#xff1b;&#xff08;这些更多的是元素的基本属性&#xff0c;HTML的属性&#xff09;&#xff1b; 目录 一&#xff1a;操作元素属性 &#xff08;1&#xff09;attr()方法&#x…...

做营销型网站费用/百度网盘资源搜索引擎

MyBatis SqlSessionFactory和SqlSession 一&#xff0c;SqlSessionFactory Mybatis提供了构造器SqlSessionFactoryBuilder来生成SqlSessionFactory。 在 MyBatis 中&#xff0c;既可以通过读取配置的 XML 文件的形式生成 SqlSessionFactory&#xff0c;也可以通过Java代码生…...

做纺织机械的网站域名/今日国内新闻最新消息10条新闻

♣题目部分在Oracle中&#xff0c;Hash Join是不是有排序&#xff1f;Hash Join会在什么时候慢&#xff1f;♣答案部分哈希连接&#xff08;Hash Join&#xff0c;HJ&#xff09;自身不需要排序&#xff0c;这是区别排序合并连接&#xff08;Sort Merge Join&#xff0c;SMJ&am…...

有专业做淘宝网站的美工吗/高端网站设计定制

Path类 提供静态方法&#xff0c;完成路径字符串的常见操作 例如在C盘的文件夹a下的b文件夹下的1.mp3文件 C:\a\b\1.mp3 一.获取信息的方法&#xff1a; 1.获得路径&#xff1a;Path.GetDirectoryName(路径); 结果&#xff1a;C:\a\b 获得文件名&#xff1a;Path.GetFileName(路…...