当前位置: 首页 > news >正文

Diffusion中的Unet (DIMP)

针对UNet2DConditionModel模型

查看Unet的源码,得知Unet的down,mid,up blocks的类型分别是:

down_block_types: Tuple[str] = ("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D","DownBlock2D",),mid_block_type: str = "UNetMidBlock2DCrossAttn",up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")

查看一下down  下采样的get_down_block方法:

def get_down_block(down_block_type,num_layers,in_channels,out_channels,temb_channels,add_downsample,resnet_eps,resnet_act_fn,attn_num_head_channels,resnet_groups=None,cross_attention_dim=None,downsample_padding=None,dual_cross_attention=False,use_linear_projection=False,only_cross_attention=False,upcast_attention=False,resnet_time_scale_shift="default",
):down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_typeif down_block_type == "DownBlock2D":return DownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "ResnetDownsampleBlock2D":return ResnetDownsampleBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnDownBlock2D":return AttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "CrossAttnDownBlock2D":if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")return CrossAttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,cross_attention_dim=cross_attention_dim,attn_num_head_channels=attn_num_head_channels,dual_cross_attention=dual_cross_attention,use_linear_projection=use_linear_projection,only_cross_attention=only_cross_attention,upcast_attention=upcast_attention,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "SimpleCrossAttnDownBlock2D":if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")return SimpleCrossAttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,cross_attention_dim=cross_attention_dim,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "SkipDownBlock2D":return SkipDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnSkipDownBlock2D":return AttnSkipDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "DownEncoderBlock2D":return DownEncoderBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnDownEncoderBlock2D":return AttnDownEncoderBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)raise ValueError(f"{down_block_type} does not exist.")

 我们看一下该Unet的forward函数:

def forward(self,sample: torch.FloatTensor,timestep: Union[torch.Tensor, float, int],encoder_hidden_states: torch.Tensor,class_labels: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,cross_attention_kwargs: Optional[Dict[str, Any]] = None,return_dict: bool = True,) -> Union[UNet2DConditionOutput, Tuple]:r"""Args:sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensortimestep (`torch.FloatTensor` or `float` or `int`): (batch) timestepsencoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden statesreturn_dict (`bool`, *optional*, defaults to `True`):Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.Returns:[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. Whenreturning a tuple, the first element is the sample tensor."""# By default samples have to be AT least a multiple of the overall upsampling factor.# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).# However, the upsampling interpolation output size can be forced to fit any upsampling size# on the fly if necessary.default_overall_up_factor = 2**self.num_upsamplers# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`forward_upsample_size = Falseupsample_size = Noneif any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):logger.info("Forward upsample size to force interpolation output size.")forward_upsample_size = True# prepare attention_maskif attention_mask is not None:attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0attention_mask = attention_mask.unsqueeze(1)# 0. center input if necessaryif self.config.center_input_sample:sample = 2 * sample - 1.0# 1. timetimesteps = timestepif not torch.is_tensor(timesteps):# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can# This would be a good case for the `match` statement (Python 3.10+)is_mps = sample.device.type == "mps"if isinstance(timestep, float):dtype = torch.float32 if is_mps else torch.float64else:dtype = torch.int32 if is_mps else torch.int64timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)elif len(timesteps.shape) == 0:timesteps = timesteps[None].to(sample.device)# broadcast to batch dimension in a way that's compatible with ONNX/Core MLtimesteps = timesteps.expand(sample.shape[0])t_emb = self.time_proj(timesteps)# timesteps does not contain any weights and will always return f32 tensors# but time_embedding might actually be running in fp16. so we need to cast here.# there might be better ways to encapsulate this.t_emb = t_emb.to(dtype=self.dtype)emb = self.time_embedding(t_emb)if self.class_embedding is not None:if class_labels is None:raise ValueError("class_labels should be provided when num_class_embeds > 0")if self.config.class_embed_type == "timestep":class_labels = self.time_proj(class_labels)class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)emb = emb + class_emb# 2. pre-processsample = self.conv_in(sample)# 3. downdown_block_res_samples = (sample,)for downsample_block in self.down_blocks:if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:sample, res_samples = downsample_block(hidden_states=sample,temb=emb,encoder_hidden_states=encoder_hidden_states,attention_mask=attention_mask,cross_attention_kwargs=cross_attention_kwargs,)else:sample, res_samples = downsample_block(hidden_states=sample, temb=emb)down_block_res_samples += res_samples# 4. midsample = self.mid_block(sample,emb,encoder_hidden_states=encoder_hidden_states,attention_mask=attention_mask,cross_attention_kwargs=cross_attention_kwargs,)# 5. upfor i, upsample_block in enumerate(self.up_blocks):is_final_block = i == len(self.up_blocks) - 1res_samples = down_block_res_samples[-len(upsample_block.resnets) :]down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]# if we have not reached the final block and need to forward the# upsample size, we do it hereif not is_final_block and forward_upsample_size:upsample_size = down_block_res_samples[-1].shape[2:]if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:sample = upsample_block(hidden_states=sample,temb=emb,res_hidden_states_tuple=res_samples,encoder_hidden_states=encoder_hidden_states,cross_attention_kwargs=cross_attention_kwargs,upsample_size=upsample_size,attention_mask=attention_mask,)else:sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)# 6. post-processsample = self.conv_norm_out(sample)sample = self.conv_act(sample)sample = self.conv_out(sample)if not return_dict:return (sample,)return UNet2DConditionOutput(sample=sample)

也就是说在:down,mid和up Block时候都有传入text_embedding的信息encoder_hidden_states和cross attention的控制:cross_attention_kwargs.

具体每一个Block的实现看源码

相关文章:

Diffusion中的Unet (DIMP)

针对UNet2DConditionModel模型 查看Unet的源码,得知Unet的down,mid,up blocks的类型分别是: down_block_types: Tuple[str] ("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D","DownBlock2…...

编译以前项目更改在x64下面时报错:函数“PVOID GetCurrentFiber(void)”已有主体

win32下面编译成功,但是x64报错 1>GetWord.c 1>md5.c 这两个文件无法编译 1>C:\Program Files (x86)\Windows Kits\10\Include\10.0.22000.0\um\winnt.h(24125,1): error C2084: 函数“PVOID GetCurrentFiber(void)”已有主体 1>C:\Program Files (x…...

【AIGC】大模型面试高频考点-数据清洗篇

【AIGC】大模型面试高频考点-数据清洗篇 (一)常用文本清洗方法1.去除无用的符号2.去除表情符号3.文本只保留汉字4.中文繁体、简体转换5.删除 HTML 标签和特殊字符6.标记化7.小写8.停用词删除9.词干提取和词形还原10.处理缺失数据11.删除重复文本12.处理嘈…...

当测试时间与测试资源有限时,你会如何优化测试策略?

1.优先级排序:根据项目的需求和紧急程度进行优先级排序,将测试用例用例划分优先级,合理安排测试资源 和时间。这样能够保障在有限的时间内测试到最关键的功能 2.提前介入测试:在开发过程中提前进行测试,可以迅速发现问…...

基于R语言森林生态系统结构、功能与稳定性分析与可视化

在生态学研究中,森林生态系统的结构、功能与稳定性是核心研究内容之一。这些方面不仅关系到森林动态变化和物种多样性,还直接影响森林提供的生态服务功能及其应对环境变化的能力。森林生态系统的结构主要包括物种组成、树种多样性、树木的空间分布与密度…...

如何使用 Python 实现插件式架构

使用 Python 实现插件式架构可以通过动态加载和调用模块或类,构建一个易于扩展和维护的系统。以下是实现插件式架构的步骤和核心思想。 1. 插件式架构核心概念 主程序:负责加载、管理插件,并调用插件的功能。插件:独立的模块或类…...

【北京迅为】iTOP-4412全能版使用手册-第二十章 搭建和测试NFS服务器

iTOP-4412全能版采用四核Cortex-A9,主频为1.4GHz-1.6GHz,配备S5M8767 电源管理,集成USB HUB,选用高品质板对板连接器稳定可靠,大厂生产,做工精良。接口一应俱全,开发更简单,搭载全网通4G、支持WIFI、蓝牙、…...

【纯原生js】原生实现h5落地页面中的单选组件按钮及功能

h5端的按钮系统自带的一般都很丑&#xff0c;需要我们进行二次美化&#xff0c;比如单选按钮复选框之类的&#xff0c;那怎么对其进行html和css的改造&#xff1f; 实现效果 实现代码 <section id"tags"><h2>给景区添加标题</h2><label><…...

深入浅出:开发者如何快速上手Web3生态系统

Web3作为互联网的未来发展方向&#xff0c;正在逐步改变传统互联网架构&#xff0c;推动去中心化技术的发展。对于开发者而言&#xff0c;Web3代表着一个充满机遇与挑战的新领域&#xff0c;学习和掌握Web3的基本技术和工具&#xff0c;将为未来的项目开发提供强大的支持。那么…...

通过深度点图表示的隐式场实现肺树结构的高效解剖标注文献速递-生成式模型与transformer在医学影像中的应用

Title 题目 Efficient anatomical labeling of pulmonary tree structures via deeppoint-graph representation-based implicit fields 通过深度点图表示的隐式场实现肺树结构的高效解剖标注 01 文献速递介绍 近年来&#xff0c;肺部疾病&#xff08;Decramer等&#xff…...

数据结构 (17)广义表

前言 数据结构中的广义表&#xff08;Generalized List&#xff0c;又称列表Lists&#xff09;是一种重要的数据结构&#xff0c;它是对线性表的一种推广&#xff0c;放松了对表元素的原子限制&#xff0c;容许它们具有其自身的结构。 一、定义与表示 定义&#xff1a;广义表是…...

论文笔记 SliceGPT: Compress Large Language Models By Deleting Rows And Columns

欲买桂花同载酒&#xff0c;终不似&#xff0c;少年游。 数学知识 秩&#xff1a; 矩阵中最大线性无关的行/列向量数。行秩与列秩相等。 线性无关&#xff1a;对于N个向量而言&#xff0c;如果任取一个向量 v \textbf{v} v&#xff0c;不能被剩下的N-1个向量通过线性组合的方式…...

前端工具的选择和安装

选择和安装前端工具是前端开发过程中的重要步骤。现代前端开发需要一些工具来提高效率和协作能力。以下是一些常用的前端工具及其选择和安装指南。 1. 代码编辑器 选择一个好的代码编辑器可以显著提高开发效率。以下是几款流行的代码编辑器&#xff1a; Visual Studio Code (…...

Fantasy中定时器得驱动原理

一、服务器框架启动 public static async FTask Start(){// 启动ProcessStartProcess().Coroutine();await FTask.CompletedTask;while (true){ThreadScheduler.Update();Thread.Sleep(1);}} 二、主线程 Fantasy.ThreadScheduler.Update internal static void Update(){MainS…...

【反转链表】力扣 445. 两数相加 II

一、题目 二、思路 加法运算是从低位开始&#xff0c;向高位进位&#xff0c;因此需要将两个链表进行反转&#xff0c;再进行对齐后的相加操作。力扣 2. 两数相加 三、题解 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode …...

SpringBoot 项目中使用 spring-boot-starter-amqp 依赖实现 RabbitMQ

文章目录 前言1、application.yml2、RabbitMqConfig3、MqMessage4、MqMessageItem5、DirectMode6、StateConsumer&#xff1a;消费者7、InfoConsumer&#xff1a;消费者 前言 本文是工作之余的随手记&#xff0c;记录在工作期间使用 RabbitMQ 的笔记。 1、application.yml 使…...

Uniapp 安装安卓、IOS模拟器并调试

一、安装Android模拟器并调试 1. 下载并安装 Android Studio 首先下载 Mac 环境下的 Android Studio 的安装包&#xff0c;为dmg 格式。 下载完将Android Studio 向右拖拽到Applications中&#xff0c;接下来等待安装完成就OK啦&#xff01; 打开过程界面如下图所示&#xf…...

JavaScript 中的原型和原型链

JavaScript 中的原型和原型链也是一个相对较难理解透彻的知识点&#xff0c;下面结合详细例子来进行说明&#xff1a; 一、原型的概念 在 JavaScript 中&#xff0c;每个函数都有一个 prototype 属性&#xff0c;这个属性指向一个对象&#xff0c;这个对象就是所谓的 “原型对…...

数组变换(两倍)

数组变换 以最大元素为基准元素&#xff0c;判读其他元素能否通过 x 2 成为最大值&#xff01; 那么怎么判断呢&#xff1a; max % arr[i] 0arr[i] * 2 ^n max int x 2 ^ n max / arr[i] 3.只需判断 这个 x 是不是 2 的 n 次放就可以了&#xff01; 判断 是否为 2 的 n 次 …...

GBN协议、SR协议

1、回退N步&#xff08;Go-Back-N,GBN&#xff09;协议&#xff1a; 总结&#xff1a; GBN协议的特点&#xff1a; &#xff08;1&#xff09;累计确认机制&#xff1a;当发送方收到ACKn时&#xff0c;表明接收方已正确接收序号为n以及序号小于n的所有分组&#xff0c;发送窗…...

三维扫描检测仪3d扫描测量尺寸-自动蓝光测量

在现代工业及生产过程中&#xff0c;精确、高效的尺寸检测是保证产品质量、提升生产效率的关键因素。 红、蓝光测量&#xff0c;以其高精度、高效率和非接触式的特点&#xff0c;在工业及生产中发挥着越来越重要的作用。蓝光测量技术利用蓝色激光光源&#xff0c;通过扫描被测…...

大模型翻译能力评测

1. 背景介绍 随着自然语言处理技术的飞速发展&#xff0c;机器翻译已经成为一个重要的研究领域。近年来&#xff0c;基于大模型的语言模型在机器翻译任务上取得了显著的进展。这些大模型通常具有数亿甚至数千亿的参数&#xff0c;能够更好地理解和生成自然语言。 但是&#xf…...

MySQL隐式转换造成索引失效

一、什么是 MySQL 的隐式转换&#xff1f; MySQL 在执行查询语句时&#xff0c;有时候会自动帮我们进行数据类型的转换&#xff0c;这个过程就是隐式转换。比如说&#xff0c;我们在一个 INT 类型的字段上进行查询&#xff0c;但是传入的查询条件却是字符串类型的值&#xff0c…...

SuperMap Objects组件式GIS开发技术浅析

引言 随着GIS应用领域的扩展&#xff0c;GIS开发工作日显重要。一般地&#xff0c;从平台和模式上划分&#xff0c;GIS二次开发主要有三种实现方式&#xff1a;独立开发、单纯二次开发和集成二次开发。上述的GIS应用开发方式各有利弊&#xff0c;其中集成二次开发既可以充分利…...

多组数输入a+b:JAVA

链接&#xff1a;登录—专业IT笔试面试备考平台_牛客网 来源&#xff1a;牛客网 输入描述: 输入包含多组数据&#xff0c;每组数据输入一行&#xff0c;包含两个整数 输出描述: 对于每组数据输出一行包含一个整数表示两个整数的和 代码: import java.util.Scanner; pu…...

R语言结构方程模型(SEM)在生态学领域中的应用

目录 专题一、R/Rstudio简介及入门 专题二、结构方程模型&#xff08;SEM&#xff09;介绍 专题三&#xff1a;R语言SEM分析入门&#xff1a;lavaan VS piecewiseSEM 专题四&#xff1a;SEM全局估计&#xff08;lavaan&#xff09;在生态学领域高阶应用 专题五&#xff1…...

架构-微服务-服务调用Dubbo

文章目录 前言一、Dubbo介绍1. 什么是Dubbo 二、实现1. 提供统一业务api2. 提供服务提供者3. 提供服务消费者 前言 服务调用方案--Dubbo‌ 基于 Java 的高性能 RPC分布式服务框架&#xff0c;致力于提供高性能和透明化的 RPC远程服务调用方案&#xff0c;以及SOA服务治理方案。…...

【SpringBoot问题】IDEA中用Service窗口展示所有服务及端口的办法

1、调出Service窗口 打开View→Tool Windows→Service&#xff0c;即可显示。 2、正常情况应该已经出现SpringBoot&#xff0c;如下图请继续第三步 3、配置Service窗口的项目启动类型。微服务一般是Springboot类型。所以这里需要选择一下。 点击最后一个号&#xff0c;点击Ru…...

OpenCV 图像轮廓查找与绘制全攻略:从函数使用到实战应用详解

摘要&#xff1a;本文详细介绍了 OpenCV 中用于查找图像轮廓的 cv2.findContours() 函数以及绘制轮廓的 cv2.drawContours() 函数的使用方法。涵盖 cv2.findContours() 各参数&#xff08;如 mode 不同取值对应不同轮廓检索模式&#xff09;及返回值的详细解析&#xff0c;搭配…...

电机驱动MCU介绍

电机驱动MCU是一种专为电机控制设计的微控制器单元&#xff0c;它集成了先进的控制算法和高性能的功率输出能力。 电机驱动MCU采用高性能的处理器核心&#xff0c;具有快速的运算速度和丰富的外设接口。它内置了专业的电机控制算法&#xff0c;包括PID控制、FOC&#xff08;Fi…...

iis7搭建网站织梦/淘宝指数网址

一、选择交换机的主要技能指标 交换机&#xff1a;交换机&#xff08;Switch&#xff09;意为“开关”是一种用于电&#xff08;光&#xff09;信号转发的网络设备。它可以为接入交换机的任意两个网络节点提供独享的电信号通路。最常见的交换机是以太网交换机。其他常见的还有电…...

网站编辑怎么做内容分类/谷歌浏览器官网下载安装

我们在日常开发中少不了和JSON数据打交道&#xff0c;那么我们来看看JAVA中常用的JSON解析方式。1、JSON官方2、GSON3、FastJSON4、jacksonJSON操作涉及到的类&#xff1a;public class Student {private int id;private String name;private int age;public int getId() {retu…...

长沙网站建设规划/合肥网

1,报错提示: 编辑器或项目正在尝试签出在内存中修改的文件&#xff0c;这将导致保存该文件。 在生成过程中保存文件是危险的&#xff0c;这可能会在将来导致不正确的生成输出。 是否仍然继续签出? 2,原因:licenses.licx属性设为了只读. 3,解决: a,搜索licenses.licx,去掉只读属…...

内容不相关的网站做301重定向/免费的短视频app大全

网络协议的定义&#xff1a;为计算机网络中进行数据交换而建立的规则、标准或约定的集合。例如&#xff0c;网络中一个微机用户和一个大型主机的操作员进行通信&#xff0c;由于这两个数据终端所用字符集不同&#xff0c;因此操作员所输入的命令彼此不认识。为了能进行通信&…...

开发网站开票名称是什么/如何投放网络广告

3、评测平台介绍及方法说明AMD FM1(APU)平台CPU AMD A6-3650(4核/4线程)主板 华硕 F1A75-M PRO(A75)内存 宇瞻 DDR3-1600 2G x 2(8-8-8-24)硬盘 日立 1TB显卡 Radeon HD 6530D(APU内置)Radeon HD 6670 双显卡交火Radeon HD 6570 双显卡交火Intel LGA1155平台CPU Intel Core i3 …...

wordpress图片美化/郑州seo学校

使用tcp通讯, 1 实现连接服务器 2 收发数据并显示 下载地址:http://download.csdn.net/download/taoerit/9964309 Tcp通讯详解:http://blog.csdn.net/taoerit/article/details/77598564 效果图: <...