Diffusion中的Unet (DIMP)
针对UNet2DConditionModel模型
查看Unet的源码,得知Unet的down,mid,up blocks的类型分别是:
down_block_types: Tuple[str] = ("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D","DownBlock2D",),mid_block_type: str = "UNetMidBlock2DCrossAttn",up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
查看一下down 下采样的get_down_block方法:
def get_down_block(down_block_type,num_layers,in_channels,out_channels,temb_channels,add_downsample,resnet_eps,resnet_act_fn,attn_num_head_channels,resnet_groups=None,cross_attention_dim=None,downsample_padding=None,dual_cross_attention=False,use_linear_projection=False,only_cross_attention=False,upcast_attention=False,resnet_time_scale_shift="default",
):down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_typeif down_block_type == "DownBlock2D":return DownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "ResnetDownsampleBlock2D":return ResnetDownsampleBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnDownBlock2D":return AttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "CrossAttnDownBlock2D":if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")return CrossAttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,cross_attention_dim=cross_attention_dim,attn_num_head_channels=attn_num_head_channels,dual_cross_attention=dual_cross_attention,use_linear_projection=use_linear_projection,only_cross_attention=only_cross_attention,upcast_attention=upcast_attention,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "SimpleCrossAttnDownBlock2D":if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")return SimpleCrossAttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,cross_attention_dim=cross_attention_dim,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "SkipDownBlock2D":return SkipDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnSkipDownBlock2D":return AttnSkipDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "DownEncoderBlock2D":return DownEncoderBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnDownEncoderBlock2D":return AttnDownEncoderBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)raise ValueError(f"{down_block_type} does not exist.")
我们看一下该Unet的forward函数:
def forward(self,sample: torch.FloatTensor,timestep: Union[torch.Tensor, float, int],encoder_hidden_states: torch.Tensor,class_labels: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,cross_attention_kwargs: Optional[Dict[str, Any]] = None,return_dict: bool = True,) -> Union[UNet2DConditionOutput, Tuple]:r"""Args:sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensortimestep (`torch.FloatTensor` or `float` or `int`): (batch) timestepsencoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden statesreturn_dict (`bool`, *optional*, defaults to `True`):Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.Returns:[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. Whenreturning a tuple, the first element is the sample tensor."""# By default samples have to be AT least a multiple of the overall upsampling factor.# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).# However, the upsampling interpolation output size can be forced to fit any upsampling size# on the fly if necessary.default_overall_up_factor = 2**self.num_upsamplers# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`forward_upsample_size = Falseupsample_size = Noneif any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):logger.info("Forward upsample size to force interpolation output size.")forward_upsample_size = True# prepare attention_maskif attention_mask is not None:attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0attention_mask = attention_mask.unsqueeze(1)# 0. center input if necessaryif self.config.center_input_sample:sample = 2 * sample - 1.0# 1. timetimesteps = timestepif not torch.is_tensor(timesteps):# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can# This would be a good case for the `match` statement (Python 3.10+)is_mps = sample.device.type == "mps"if isinstance(timestep, float):dtype = torch.float32 if is_mps else torch.float64else:dtype = torch.int32 if is_mps else torch.int64timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)elif len(timesteps.shape) == 0:timesteps = timesteps[None].to(sample.device)# broadcast to batch dimension in a way that's compatible with ONNX/Core MLtimesteps = timesteps.expand(sample.shape[0])t_emb = self.time_proj(timesteps)# timesteps does not contain any weights and will always return f32 tensors# but time_embedding might actually be running in fp16. so we need to cast here.# there might be better ways to encapsulate this.t_emb = t_emb.to(dtype=self.dtype)emb = self.time_embedding(t_emb)if self.class_embedding is not None:if class_labels is None:raise ValueError("class_labels should be provided when num_class_embeds > 0")if self.config.class_embed_type == "timestep":class_labels = self.time_proj(class_labels)class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)emb = emb + class_emb# 2. pre-processsample = self.conv_in(sample)# 3. downdown_block_res_samples = (sample,)for downsample_block in self.down_blocks:if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:sample, res_samples = downsample_block(hidden_states=sample,temb=emb,encoder_hidden_states=encoder_hidden_states,attention_mask=attention_mask,cross_attention_kwargs=cross_attention_kwargs,)else:sample, res_samples = downsample_block(hidden_states=sample, temb=emb)down_block_res_samples += res_samples# 4. midsample = self.mid_block(sample,emb,encoder_hidden_states=encoder_hidden_states,attention_mask=attention_mask,cross_attention_kwargs=cross_attention_kwargs,)# 5. upfor i, upsample_block in enumerate(self.up_blocks):is_final_block = i == len(self.up_blocks) - 1res_samples = down_block_res_samples[-len(upsample_block.resnets) :]down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]# if we have not reached the final block and need to forward the# upsample size, we do it hereif not is_final_block and forward_upsample_size:upsample_size = down_block_res_samples[-1].shape[2:]if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:sample = upsample_block(hidden_states=sample,temb=emb,res_hidden_states_tuple=res_samples,encoder_hidden_states=encoder_hidden_states,cross_attention_kwargs=cross_attention_kwargs,upsample_size=upsample_size,attention_mask=attention_mask,)else:sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)# 6. post-processsample = self.conv_norm_out(sample)sample = self.conv_act(sample)sample = self.conv_out(sample)if not return_dict:return (sample,)return UNet2DConditionOutput(sample=sample)
也就是说在:down,mid和up Block时候都有传入text_embedding的信息encoder_hidden_states和cross attention的控制:cross_attention_kwargs.
具体每一个Block的实现看源码
相关文章:
Diffusion中的Unet (DIMP)
针对UNet2DConditionModel模型 查看Unet的源码,得知Unet的down,mid,up blocks的类型分别是: down_block_types: Tuple[str] ("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D","DownBlock2…...
编译以前项目更改在x64下面时报错:函数“PVOID GetCurrentFiber(void)”已有主体
win32下面编译成功,但是x64报错 1>GetWord.c 1>md5.c 这两个文件无法编译 1>C:\Program Files (x86)\Windows Kits\10\Include\10.0.22000.0\um\winnt.h(24125,1): error C2084: 函数“PVOID GetCurrentFiber(void)”已有主体 1>C:\Program Files (x…...
【AIGC】大模型面试高频考点-数据清洗篇
【AIGC】大模型面试高频考点-数据清洗篇 (一)常用文本清洗方法1.去除无用的符号2.去除表情符号3.文本只保留汉字4.中文繁体、简体转换5.删除 HTML 标签和特殊字符6.标记化7.小写8.停用词删除9.词干提取和词形还原10.处理缺失数据11.删除重复文本12.处理嘈…...
当测试时间与测试资源有限时,你会如何优化测试策略?
1.优先级排序:根据项目的需求和紧急程度进行优先级排序,将测试用例用例划分优先级,合理安排测试资源 和时间。这样能够保障在有限的时间内测试到最关键的功能 2.提前介入测试:在开发过程中提前进行测试,可以迅速发现问…...
基于R语言森林生态系统结构、功能与稳定性分析与可视化
在生态学研究中,森林生态系统的结构、功能与稳定性是核心研究内容之一。这些方面不仅关系到森林动态变化和物种多样性,还直接影响森林提供的生态服务功能及其应对环境变化的能力。森林生态系统的结构主要包括物种组成、树种多样性、树木的空间分布与密度…...
如何使用 Python 实现插件式架构
使用 Python 实现插件式架构可以通过动态加载和调用模块或类,构建一个易于扩展和维护的系统。以下是实现插件式架构的步骤和核心思想。 1. 插件式架构核心概念 主程序:负责加载、管理插件,并调用插件的功能。插件:独立的模块或类…...
【北京迅为】iTOP-4412全能版使用手册-第二十章 搭建和测试NFS服务器
iTOP-4412全能版采用四核Cortex-A9,主频为1.4GHz-1.6GHz,配备S5M8767 电源管理,集成USB HUB,选用高品质板对板连接器稳定可靠,大厂生产,做工精良。接口一应俱全,开发更简单,搭载全网通4G、支持WIFI、蓝牙、…...
【纯原生js】原生实现h5落地页面中的单选组件按钮及功能
h5端的按钮系统自带的一般都很丑,需要我们进行二次美化,比如单选按钮复选框之类的,那怎么对其进行html和css的改造? 实现效果 实现代码 <section id"tags"><h2>给景区添加标题</h2><label><…...
深入浅出:开发者如何快速上手Web3生态系统
Web3作为互联网的未来发展方向,正在逐步改变传统互联网架构,推动去中心化技术的发展。对于开发者而言,Web3代表着一个充满机遇与挑战的新领域,学习和掌握Web3的基本技术和工具,将为未来的项目开发提供强大的支持。那么…...
通过深度点图表示的隐式场实现肺树结构的高效解剖标注文献速递-生成式模型与transformer在医学影像中的应用
Title 题目 Efficient anatomical labeling of pulmonary tree structures via deeppoint-graph representation-based implicit fields 通过深度点图表示的隐式场实现肺树结构的高效解剖标注 01 文献速递介绍 近年来,肺部疾病(Decramer等ÿ…...
数据结构 (17)广义表
前言 数据结构中的广义表(Generalized List,又称列表Lists)是一种重要的数据结构,它是对线性表的一种推广,放松了对表元素的原子限制,容许它们具有其自身的结构。 一、定义与表示 定义:广义表是…...
论文笔记 SliceGPT: Compress Large Language Models By Deleting Rows And Columns
欲买桂花同载酒,终不似,少年游。 数学知识 秩: 矩阵中最大线性无关的行/列向量数。行秩与列秩相等。 线性无关:对于N个向量而言,如果任取一个向量 v \textbf{v} v,不能被剩下的N-1个向量通过线性组合的方式…...
前端工具的选择和安装
选择和安装前端工具是前端开发过程中的重要步骤。现代前端开发需要一些工具来提高效率和协作能力。以下是一些常用的前端工具及其选择和安装指南。 1. 代码编辑器 选择一个好的代码编辑器可以显著提高开发效率。以下是几款流行的代码编辑器: Visual Studio Code (…...
Fantasy中定时器得驱动原理
一、服务器框架启动 public static async FTask Start(){// 启动ProcessStartProcess().Coroutine();await FTask.CompletedTask;while (true){ThreadScheduler.Update();Thread.Sleep(1);}} 二、主线程 Fantasy.ThreadScheduler.Update internal static void Update(){MainS…...
【反转链表】力扣 445. 两数相加 II
一、题目 二、思路 加法运算是从低位开始,向高位进位,因此需要将两个链表进行反转,再进行对齐后的相加操作。力扣 2. 两数相加 三、题解 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode …...
SpringBoot 项目中使用 spring-boot-starter-amqp 依赖实现 RabbitMQ
文章目录 前言1、application.yml2、RabbitMqConfig3、MqMessage4、MqMessageItem5、DirectMode6、StateConsumer:消费者7、InfoConsumer:消费者 前言 本文是工作之余的随手记,记录在工作期间使用 RabbitMQ 的笔记。 1、application.yml 使…...
Uniapp 安装安卓、IOS模拟器并调试
一、安装Android模拟器并调试 1. 下载并安装 Android Studio 首先下载 Mac 环境下的 Android Studio 的安装包,为dmg 格式。 下载完将Android Studio 向右拖拽到Applications中,接下来等待安装完成就OK啦! 打开过程界面如下图所示…...
JavaScript 中的原型和原型链
JavaScript 中的原型和原型链也是一个相对较难理解透彻的知识点,下面结合详细例子来进行说明: 一、原型的概念 在 JavaScript 中,每个函数都有一个 prototype 属性,这个属性指向一个对象,这个对象就是所谓的 “原型对…...
数组变换(两倍)
数组变换 以最大元素为基准元素,判读其他元素能否通过 x 2 成为最大值! 那么怎么判断呢: max % arr[i] 0arr[i] * 2 ^n max int x 2 ^ n max / arr[i] 3.只需判断 这个 x 是不是 2 的 n 次放就可以了! 判断 是否为 2 的 n 次 …...
GBN协议、SR协议
1、回退N步(Go-Back-N,GBN)协议: 总结: GBN协议的特点: (1)累计确认机制:当发送方收到ACKn时,表明接收方已正确接收序号为n以及序号小于n的所有分组,发送窗…...
多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度
一、引言:多云环境的技术复杂性本质 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时,基础设施的技术债呈现指数级积累。网络连接、身份认证、成本管理这三大核心挑战相互嵌套:跨云网络构建数据…...
conda相比python好处
Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理:…...
家政维修平台实战20:权限设计
目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系,主要是分成几个表,用户表我们是记录用户的基础信息,包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题,不同的角色…...
HBuilderX安装(uni-app和小程序开发)
下载HBuilderX 访问官方网站:https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本: Windows版(推荐下载标准版) Windows系统安装步骤 运行安装程序: 双击下载的.exe安装文件 如果出现安全提示&…...
在WSL2的Ubuntu镜像中安装Docker
Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包: for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...
中医有效性探讨
文章目录 西医是如何发展到以生物化学为药理基础的现代医学?传统医学奠基期(远古 - 17 世纪)近代医学转型期(17 世纪 - 19 世纪末)现代医学成熟期(20世纪至今) 中医的源远流长和一脉相承远古至…...
算法岗面试经验分享-大模型篇
文章目录 A 基础语言模型A.1 TransformerA.2 Bert B 大语言模型结构B.1 GPTB.2 LLamaB.3 ChatGLMB.4 Qwen C 大语言模型微调C.1 Fine-tuningC.2 Adapter-tuningC.3 Prefix-tuningC.4 P-tuningC.5 LoRA A 基础语言模型 A.1 Transformer (1)资源 论文&a…...
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材)
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材) 这个项目能干嘛? 使用 gemini 2.0 的 api 和 google 其他的 api 来做衍生处理 简化和优化了文生图和图生图的行为(我的最主要) 并且有一些目标检测和切割(我用不到) 视频和 imagefx 因为没 a…...
使用Spring AI和MCP协议构建图片搜索服务
目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式(本地调用) SSE模式(远程调用) 4. 注册工具提…...
使用LangGraph和LangSmith构建多智能体人工智能系统
现在,通过组合几个较小的子智能体来创建一个强大的人工智能智能体正成为一种趋势。但这也带来了一些挑战,比如减少幻觉、管理对话流程、在测试期间留意智能体的工作方式、允许人工介入以及评估其性能。你需要进行大量的反复试验。 在这篇博客〔原作者&a…...
