Diffusion中的Unet (DIMP)
针对UNet2DConditionModel模型
查看Unet的源码,得知Unet的down,mid,up blocks的类型分别是:
down_block_types: Tuple[str] = ("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D","DownBlock2D",),mid_block_type: str = "UNetMidBlock2DCrossAttn",up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
查看一下down 下采样的get_down_block方法:
def get_down_block(down_block_type,num_layers,in_channels,out_channels,temb_channels,add_downsample,resnet_eps,resnet_act_fn,attn_num_head_channels,resnet_groups=None,cross_attention_dim=None,downsample_padding=None,dual_cross_attention=False,use_linear_projection=False,only_cross_attention=False,upcast_attention=False,resnet_time_scale_shift="default",
):down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_typeif down_block_type == "DownBlock2D":return DownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "ResnetDownsampleBlock2D":return ResnetDownsampleBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnDownBlock2D":return AttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "CrossAttnDownBlock2D":if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")return CrossAttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,cross_attention_dim=cross_attention_dim,attn_num_head_channels=attn_num_head_channels,dual_cross_attention=dual_cross_attention,use_linear_projection=use_linear_projection,only_cross_attention=only_cross_attention,upcast_attention=upcast_attention,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "SimpleCrossAttnDownBlock2D":if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")return SimpleCrossAttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,cross_attention_dim=cross_attention_dim,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "SkipDownBlock2D":return SkipDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnSkipDownBlock2D":return AttnSkipDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "DownEncoderBlock2D":return DownEncoderBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnDownEncoderBlock2D":return AttnDownEncoderBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)raise ValueError(f"{down_block_type} does not exist.")
我们看一下该Unet的forward函数:
def forward(self,sample: torch.FloatTensor,timestep: Union[torch.Tensor, float, int],encoder_hidden_states: torch.Tensor,class_labels: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,cross_attention_kwargs: Optional[Dict[str, Any]] = None,return_dict: bool = True,) -> Union[UNet2DConditionOutput, Tuple]:r"""Args:sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensortimestep (`torch.FloatTensor` or `float` or `int`): (batch) timestepsencoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden statesreturn_dict (`bool`, *optional*, defaults to `True`):Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.Returns:[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. Whenreturning a tuple, the first element is the sample tensor."""# By default samples have to be AT least a multiple of the overall upsampling factor.# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).# However, the upsampling interpolation output size can be forced to fit any upsampling size# on the fly if necessary.default_overall_up_factor = 2**self.num_upsamplers# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`forward_upsample_size = Falseupsample_size = Noneif any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):logger.info("Forward upsample size to force interpolation output size.")forward_upsample_size = True# prepare attention_maskif attention_mask is not None:attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0attention_mask = attention_mask.unsqueeze(1)# 0. center input if necessaryif self.config.center_input_sample:sample = 2 * sample - 1.0# 1. timetimesteps = timestepif not torch.is_tensor(timesteps):# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can# This would be a good case for the `match` statement (Python 3.10+)is_mps = sample.device.type == "mps"if isinstance(timestep, float):dtype = torch.float32 if is_mps else torch.float64else:dtype = torch.int32 if is_mps else torch.int64timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)elif len(timesteps.shape) == 0:timesteps = timesteps[None].to(sample.device)# broadcast to batch dimension in a way that's compatible with ONNX/Core MLtimesteps = timesteps.expand(sample.shape[0])t_emb = self.time_proj(timesteps)# timesteps does not contain any weights and will always return f32 tensors# but time_embedding might actually be running in fp16. so we need to cast here.# there might be better ways to encapsulate this.t_emb = t_emb.to(dtype=self.dtype)emb = self.time_embedding(t_emb)if self.class_embedding is not None:if class_labels is None:raise ValueError("class_labels should be provided when num_class_embeds > 0")if self.config.class_embed_type == "timestep":class_labels = self.time_proj(class_labels)class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)emb = emb + class_emb# 2. pre-processsample = self.conv_in(sample)# 3. downdown_block_res_samples = (sample,)for downsample_block in self.down_blocks:if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:sample, res_samples = downsample_block(hidden_states=sample,temb=emb,encoder_hidden_states=encoder_hidden_states,attention_mask=attention_mask,cross_attention_kwargs=cross_attention_kwargs,)else:sample, res_samples = downsample_block(hidden_states=sample, temb=emb)down_block_res_samples += res_samples# 4. midsample = self.mid_block(sample,emb,encoder_hidden_states=encoder_hidden_states,attention_mask=attention_mask,cross_attention_kwargs=cross_attention_kwargs,)# 5. upfor i, upsample_block in enumerate(self.up_blocks):is_final_block = i == len(self.up_blocks) - 1res_samples = down_block_res_samples[-len(upsample_block.resnets) :]down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]# if we have not reached the final block and need to forward the# upsample size, we do it hereif not is_final_block and forward_upsample_size:upsample_size = down_block_res_samples[-1].shape[2:]if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:sample = upsample_block(hidden_states=sample,temb=emb,res_hidden_states_tuple=res_samples,encoder_hidden_states=encoder_hidden_states,cross_attention_kwargs=cross_attention_kwargs,upsample_size=upsample_size,attention_mask=attention_mask,)else:sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)# 6. post-processsample = self.conv_norm_out(sample)sample = self.conv_act(sample)sample = self.conv_out(sample)if not return_dict:return (sample,)return UNet2DConditionOutput(sample=sample)
也就是说在:down,mid和up Block时候都有传入text_embedding的信息encoder_hidden_states和cross attention的控制:cross_attention_kwargs.
具体每一个Block的实现看源码
相关文章:
Diffusion中的Unet (DIMP)
针对UNet2DConditionModel模型 查看Unet的源码,得知Unet的down,mid,up blocks的类型分别是: down_block_types: Tuple[str] ("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D","DownBlock2…...
编译以前项目更改在x64下面时报错:函数“PVOID GetCurrentFiber(void)”已有主体
win32下面编译成功,但是x64报错 1>GetWord.c 1>md5.c 这两个文件无法编译 1>C:\Program Files (x86)\Windows Kits\10\Include\10.0.22000.0\um\winnt.h(24125,1): error C2084: 函数“PVOID GetCurrentFiber(void)”已有主体 1>C:\Program Files (x…...
【AIGC】大模型面试高频考点-数据清洗篇
【AIGC】大模型面试高频考点-数据清洗篇 (一)常用文本清洗方法1.去除无用的符号2.去除表情符号3.文本只保留汉字4.中文繁体、简体转换5.删除 HTML 标签和特殊字符6.标记化7.小写8.停用词删除9.词干提取和词形还原10.处理缺失数据11.删除重复文本12.处理嘈…...
当测试时间与测试资源有限时,你会如何优化测试策略?
1.优先级排序:根据项目的需求和紧急程度进行优先级排序,将测试用例用例划分优先级,合理安排测试资源 和时间。这样能够保障在有限的时间内测试到最关键的功能 2.提前介入测试:在开发过程中提前进行测试,可以迅速发现问…...
基于R语言森林生态系统结构、功能与稳定性分析与可视化
在生态学研究中,森林生态系统的结构、功能与稳定性是核心研究内容之一。这些方面不仅关系到森林动态变化和物种多样性,还直接影响森林提供的生态服务功能及其应对环境变化的能力。森林生态系统的结构主要包括物种组成、树种多样性、树木的空间分布与密度…...
如何使用 Python 实现插件式架构
使用 Python 实现插件式架构可以通过动态加载和调用模块或类,构建一个易于扩展和维护的系统。以下是实现插件式架构的步骤和核心思想。 1. 插件式架构核心概念 主程序:负责加载、管理插件,并调用插件的功能。插件:独立的模块或类…...
【北京迅为】iTOP-4412全能版使用手册-第二十章 搭建和测试NFS服务器
iTOP-4412全能版采用四核Cortex-A9,主频为1.4GHz-1.6GHz,配备S5M8767 电源管理,集成USB HUB,选用高品质板对板连接器稳定可靠,大厂生产,做工精良。接口一应俱全,开发更简单,搭载全网通4G、支持WIFI、蓝牙、…...
【纯原生js】原生实现h5落地页面中的单选组件按钮及功能
h5端的按钮系统自带的一般都很丑,需要我们进行二次美化,比如单选按钮复选框之类的,那怎么对其进行html和css的改造? 实现效果 实现代码 <section id"tags"><h2>给景区添加标题</h2><label><…...
深入浅出:开发者如何快速上手Web3生态系统
Web3作为互联网的未来发展方向,正在逐步改变传统互联网架构,推动去中心化技术的发展。对于开发者而言,Web3代表着一个充满机遇与挑战的新领域,学习和掌握Web3的基本技术和工具,将为未来的项目开发提供强大的支持。那么…...
通过深度点图表示的隐式场实现肺树结构的高效解剖标注文献速递-生成式模型与transformer在医学影像中的应用
Title 题目 Efficient anatomical labeling of pulmonary tree structures via deeppoint-graph representation-based implicit fields 通过深度点图表示的隐式场实现肺树结构的高效解剖标注 01 文献速递介绍 近年来,肺部疾病(Decramer等ÿ…...
数据结构 (17)广义表
前言 数据结构中的广义表(Generalized List,又称列表Lists)是一种重要的数据结构,它是对线性表的一种推广,放松了对表元素的原子限制,容许它们具有其自身的结构。 一、定义与表示 定义:广义表是…...
论文笔记 SliceGPT: Compress Large Language Models By Deleting Rows And Columns
欲买桂花同载酒,终不似,少年游。 数学知识 秩: 矩阵中最大线性无关的行/列向量数。行秩与列秩相等。 线性无关:对于N个向量而言,如果任取一个向量 v \textbf{v} v,不能被剩下的N-1个向量通过线性组合的方式…...
前端工具的选择和安装
选择和安装前端工具是前端开发过程中的重要步骤。现代前端开发需要一些工具来提高效率和协作能力。以下是一些常用的前端工具及其选择和安装指南。 1. 代码编辑器 选择一个好的代码编辑器可以显著提高开发效率。以下是几款流行的代码编辑器: Visual Studio Code (…...
Fantasy中定时器得驱动原理
一、服务器框架启动 public static async FTask Start(){// 启动ProcessStartProcess().Coroutine();await FTask.CompletedTask;while (true){ThreadScheduler.Update();Thread.Sleep(1);}} 二、主线程 Fantasy.ThreadScheduler.Update internal static void Update(){MainS…...
【反转链表】力扣 445. 两数相加 II
一、题目 二、思路 加法运算是从低位开始,向高位进位,因此需要将两个链表进行反转,再进行对齐后的相加操作。力扣 2. 两数相加 三、题解 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode …...
SpringBoot 项目中使用 spring-boot-starter-amqp 依赖实现 RabbitMQ
文章目录 前言1、application.yml2、RabbitMqConfig3、MqMessage4、MqMessageItem5、DirectMode6、StateConsumer:消费者7、InfoConsumer:消费者 前言 本文是工作之余的随手记,记录在工作期间使用 RabbitMQ 的笔记。 1、application.yml 使…...
Uniapp 安装安卓、IOS模拟器并调试
一、安装Android模拟器并调试 1. 下载并安装 Android Studio 首先下载 Mac 环境下的 Android Studio 的安装包,为dmg 格式。 下载完将Android Studio 向右拖拽到Applications中,接下来等待安装完成就OK啦! 打开过程界面如下图所示…...
JavaScript 中的原型和原型链
JavaScript 中的原型和原型链也是一个相对较难理解透彻的知识点,下面结合详细例子来进行说明: 一、原型的概念 在 JavaScript 中,每个函数都有一个 prototype 属性,这个属性指向一个对象,这个对象就是所谓的 “原型对…...
数组变换(两倍)
数组变换 以最大元素为基准元素,判读其他元素能否通过 x 2 成为最大值! 那么怎么判断呢: max % arr[i] 0arr[i] * 2 ^n max int x 2 ^ n max / arr[i] 3.只需判断 这个 x 是不是 2 的 n 次放就可以了! 判断 是否为 2 的 n 次 …...
GBN协议、SR协议
1、回退N步(Go-Back-N,GBN)协议: 总结: GBN协议的特点: (1)累计确认机制:当发送方收到ACKn时,表明接收方已正确接收序号为n以及序号小于n的所有分组,发送窗…...
三维扫描检测仪3d扫描测量尺寸-自动蓝光测量
在现代工业及生产过程中,精确、高效的尺寸检测是保证产品质量、提升生产效率的关键因素。 红、蓝光测量,以其高精度、高效率和非接触式的特点,在工业及生产中发挥着越来越重要的作用。蓝光测量技术利用蓝色激光光源,通过扫描被测…...
大模型翻译能力评测
1. 背景介绍 随着自然语言处理技术的飞速发展,机器翻译已经成为一个重要的研究领域。近年来,基于大模型的语言模型在机器翻译任务上取得了显著的进展。这些大模型通常具有数亿甚至数千亿的参数,能够更好地理解和生成自然语言。 但是…...
MySQL隐式转换造成索引失效
一、什么是 MySQL 的隐式转换? MySQL 在执行查询语句时,有时候会自动帮我们进行数据类型的转换,这个过程就是隐式转换。比如说,我们在一个 INT 类型的字段上进行查询,但是传入的查询条件却是字符串类型的值,…...
SuperMap Objects组件式GIS开发技术浅析
引言 随着GIS应用领域的扩展,GIS开发工作日显重要。一般地,从平台和模式上划分,GIS二次开发主要有三种实现方式:独立开发、单纯二次开发和集成二次开发。上述的GIS应用开发方式各有利弊,其中集成二次开发既可以充分利…...
多组数输入a+b:JAVA
链接:登录—专业IT笔试面试备考平台_牛客网 来源:牛客网 输入描述: 输入包含多组数据,每组数据输入一行,包含两个整数 输出描述: 对于每组数据输出一行包含一个整数表示两个整数的和 代码: import java.util.Scanner; pu…...
R语言结构方程模型(SEM)在生态学领域中的应用
目录 专题一、R/Rstudio简介及入门 专题二、结构方程模型(SEM)介绍 专题三:R语言SEM分析入门:lavaan VS piecewiseSEM 专题四:SEM全局估计(lavaan)在生态学领域高阶应用 专题五࿱…...
架构-微服务-服务调用Dubbo
文章目录 前言一、Dubbo介绍1. 什么是Dubbo 二、实现1. 提供统一业务api2. 提供服务提供者3. 提供服务消费者 前言 服务调用方案--Dubbo 基于 Java 的高性能 RPC分布式服务框架,致力于提供高性能和透明化的 RPC远程服务调用方案,以及SOA服务治理方案。…...
【SpringBoot问题】IDEA中用Service窗口展示所有服务及端口的办法
1、调出Service窗口 打开View→Tool Windows→Service,即可显示。 2、正常情况应该已经出现SpringBoot,如下图请继续第三步 3、配置Service窗口的项目启动类型。微服务一般是Springboot类型。所以这里需要选择一下。 点击最后一个号,点击Ru…...
OpenCV 图像轮廓查找与绘制全攻略:从函数使用到实战应用详解
摘要:本文详细介绍了 OpenCV 中用于查找图像轮廓的 cv2.findContours() 函数以及绘制轮廓的 cv2.drawContours() 函数的使用方法。涵盖 cv2.findContours() 各参数(如 mode 不同取值对应不同轮廓检索模式)及返回值的详细解析,搭配…...
电机驱动MCU介绍
电机驱动MCU是一种专为电机控制设计的微控制器单元,它集成了先进的控制算法和高性能的功率输出能力。 电机驱动MCU采用高性能的处理器核心,具有快速的运算速度和丰富的外设接口。它内置了专业的电机控制算法,包括PID控制、FOC(Fi…...
iis7搭建网站织梦/淘宝指数网址
一、选择交换机的主要技能指标 交换机:交换机(Switch)意为“开关”是一种用于电(光)信号转发的网络设备。它可以为接入交换机的任意两个网络节点提供独享的电信号通路。最常见的交换机是以太网交换机。其他常见的还有电…...
网站编辑怎么做内容分类/谷歌浏览器官网下载安装
我们在日常开发中少不了和JSON数据打交道,那么我们来看看JAVA中常用的JSON解析方式。1、JSON官方2、GSON3、FastJSON4、jacksonJSON操作涉及到的类:public class Student {private int id;private String name;private int age;public int getId() {retu…...
长沙网站建设规划/合肥网
1,报错提示: 编辑器或项目正在尝试签出在内存中修改的文件,这将导致保存该文件。 在生成过程中保存文件是危险的,这可能会在将来导致不正确的生成输出。 是否仍然继续签出? 2,原因:licenses.licx属性设为了只读. 3,解决: a,搜索licenses.licx,去掉只读属…...
内容不相关的网站做301重定向/免费的短视频app大全
网络协议的定义:为计算机网络中进行数据交换而建立的规则、标准或约定的集合。例如,网络中一个微机用户和一个大型主机的操作员进行通信,由于这两个数据终端所用字符集不同,因此操作员所输入的命令彼此不认识。为了能进行通信&…...
开发网站开票名称是什么/如何投放网络广告
3、评测平台介绍及方法说明AMD FM1(APU)平台CPU AMD A6-3650(4核/4线程)主板 华硕 F1A75-M PRO(A75)内存 宇瞻 DDR3-1600 2G x 2(8-8-8-24)硬盘 日立 1TB显卡 Radeon HD 6530D(APU内置)Radeon HD 6670 双显卡交火Radeon HD 6570 双显卡交火Intel LGA1155平台CPU Intel Core i3 …...
wordpress图片美化/郑州seo学校
使用tcp通讯, 1 实现连接服务器 2 收发数据并显示 下载地址:http://download.csdn.net/download/taoerit/9964309 Tcp通讯详解:http://blog.csdn.net/taoerit/article/details/77598564 效果图: <...