PyTorch 参数化深度解析:自定义、管理和优化模型参数
目录
torch.nn子模块parametrize
parametrize.register_parametrization
主要特性和用途
使用场景
参数和关键字参数
注意事项
示例
parametrize.remove_parametrizations
功能和用途
参数
返回值
异常
使用示例
parametrize.cached
功能和用途
如何使用
示例
parametrize.is_parametrized
功能和用途
参数
返回值
示例用法
parametrize.ParametrizationList
主要功能和特点
参数
方法
注意事项
示例
总结
torch.nn子模块parametrize
parametrize.register_parametrization
torch.nn.utils.parametrize.register_parametrization是PyTorch中的一个功能,它允许用户将自定义参数化方法应用于模块中的张量。这种方法对于改变和控制模型参数的行为非常有用,特别是在需要对参数施加特定的约束或转换时。
主要特性和用途
- 自定义参数化: 通过将参数或缓冲区与自定义的nn.Module相关联,可以对其行为进行自定义。
- 原始和参数化的版本访问: 注册后,可以通过module.parametrizations.[tensor_name].original访问原始张量,并通过module.[tensor_name]访问参数化后的版本。
- 支持链式参数化: 可以通过在同一属性上注册多个参数化来串联它们。
- 缓存系统: 内置缓存系统,可以使用cached()上下文管理器来激活,以提高效率。
- 自定义初始化: 通过实现right_inverse方法,可以自定义参数化的初始值。
使用场景
- 强制张量属性: 如强制权重矩阵为对称、正交或具有特定秩。
- 正则化和约束: 在训练过程中自动应用特定的正则化或约束。
- 模型复杂性控制: 例如,限制模型的参数数量或结构,以避免过拟合。
参数和关键字参数
- module(nn.Module): 需要注册参数化的模块。
- tensor_name(str): 需要进行参数化的参数或缓冲区的名称。
- parametrization(nn.Module): 将要注册的参数化。
- unsafe(bool, 可选): 表示参数化是否可能改变张量的数据类型和形状。默认为False。
注意事项
- 兼容性和安全性: 如果设置了unsafe=True,则在注册时不会检查参数化的一致性,这可能带来风险。
- 优化器兼容性: 如果在创建优化器后注册了新的参数化,可能需要手动将新参数添加到优化器中。
- 错误处理: 如果模块中不存在名为tensor_name的参数或缓冲区,将抛出ValueError。
示例
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P# 定义一个对称矩阵参数化
class Symmetric(nn.Module):def forward(self, X):return X.triu() + X.triu(1).Tdef right_inverse(self, A):return A.triu()# 应用参数化
m = nn.Linear(5, 5)
P.register_parametrization(m, "weight", Symmetric())
print(torch.allclose(m.weight, m.weight.T))  # 现在m.weight是对称的# 初始化对称权重
A = torch.rand(5, 5)
A = A + A.T
m.weight = A
print(torch.allclose(m.weight, A))
这个示例创建了一个线性层,对其权重应用了对称性参数化,然后初始化权重为一个对称矩阵。通过这种方法,可以确保模型的权重始终保持特定的结构特性。
parametrize.remove_parametrizations
torch.nn.utils.parametrize.remove_parametrizations 是 PyTorch 中的一个功能,它用于移除模块中某个张量上的参数化。这个函数允许用户将模块中的参数从参数化状态恢复到原始状态,根据leave_parametrized参数的设置,可以选择保留当前参数化的输出或恢复到未参数化的原始张量。
功能和用途
- 移除参数化: 当不再需要特定的参数化或者需要将模型恢复到其原始状态时,此功能非常有用。
- 灵活性: 提供了在保留参数化输出和恢复到原始状态之间选择的灵活性。
参数
- module(nn.Module): 从中移除参数化的模块。
- tensor_name(str): 要移除参数化的张量的名称。
- leave_parametrized(bool, 可选): 是否保留属性- tensor_name作为参数化的状态。默认为True。
返回值
- 返回经修改的模块(Module类型)。
异常
- 如果module[tensor_name]未被参数化,会抛出ValueError。
- 如果leave_parametrized=False且参数化依赖于多个张量,也会抛出ValueError。
使用示例
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P# 定义模块和参数化
m = nn.Linear(5, 5)
P.register_parametrization(m, "weight", ...)# 假设在这里进行了一些操作# 移除参数化,保留当前参数化的输出
P.remove_parametrizations(m, "weight", leave_parametrized=True)# 或者,移除参数化,恢复到原始未参数化的张量
P.remove_parametrizations(m, "weight", leave_parametrized=False)
 这个示例展示了如何在一个线性层上注册并最终移除参数化。根据leave_parametrized的设置,可以选择在移除参数化后保留当前的参数化状态或恢复到原始状态。这使得在模型开发和实验过程中可以更灵活地控制参数的行为。
parametrize.cached
torch.nn.utils.parametrize.cached() 是 PyTorch 框架中的一个上下文管理器,用于启用通过 register_parametrization() 注册的参数化对象的缓存系统。当这个上下文管理器活跃时,参数化对象的值在第一次被请求时会被计算和缓存。离开上下文管理器时,缓存的值会被丢弃。
功能和用途
- 性能优化: 当在前向传播中多次使用参数化参数时,启用缓存可以提高效率。这在参数化对象需要频繁计算但在单次前向传播中不变时特别有用。
- 权重共享场景: 在共享权重的情况下(例如,RNN的循环核),可以防止重复计算相同的参数化结果。
如何使用
- 通过将模型的前向传播包装在 P.cached()的上下文管理器内来激活缓存。
- 可以选择只包装使用参数化张量多次的模块部分,例如RNN的循环。
示例
import torch.nn as nn
import torch.nn.utils.parametrize as Pclass MyModel(nn.Module):# 模型定义...model = MyModel()
# 应用一些参数化
...# 使用缓存系统包装模型的前向传播
with P.cached():output = model(inputs)# 或者,仅在特定部分使用缓存
with P.cached():for x in xs:out_rnn = self.rnn_cell(x, out_rnn)
这个示例展示了如何在模型的整个前向传播过程中或者在特定部分(如RNN循环中)使用缓存系统。这样做可以在保持模型逻辑不变的同时,提高计算效率。特别是在复杂的参数化场景中,这可以显著减少不必要的重复计算。
parametrize.is_parametrized
torch.nn.utils.parametrize.is_parametrized 是 PyTorch 库中的一个函数,用于检查一个模块是否有活跃的参数化,或者指定的张量名称是否已经被参数化。
功能和用途
- 检查参数化状态: 用于确定给定的模块或其特定属性(如权重或偏置)是否已经被参数化。
- 辅助开发和调试: 在开发复杂的神经网络模型时,此函数可以帮助开发者了解模型的当前状态,特别是在使用自定义参数化时。
参数
- module(nn.Module): 要查询的模块。
- tensor_name(str, 可选): 模块中要查询的属性,默认为None。如果提供,函数将检查此特定属性是否已经被参数化。
返回值
- 返回类型为bool,表示指定模块或属性是否已经被参数化。
示例用法
import torch.nn as nn
import torch.nn.utils.parametrize as Pclass MyModel(nn.Module):# 模型定义...model = MyModel()
# 对模型的某个属性应用参数化
P.register_parametrization(model, 'weight', ...)# 检查整个模型是否被参数化
is_parametrized = P.is_parametrized(model)
print(is_parametrized)  # 输出 True 或 False# 检查模型的特定属性是否被参数化
is_weight_parametrized = P.is_parametrized(model, 'weight')
print(is_weight_parametrized)  # 输出 True 或 False
在这个示例中,is_parametrized 函数用来检查整个模型是否有任何参数化,以及模型的weight属性是否被特定地参数化。这对于验证参数化是否正确应用或在调试过程中理解模型的当前状态非常有用。
parametrize.ParametrizationList
ParametrizationList 是 PyTorch 中的一个类,它是一个顺序容器,用于保存和管理经过参数化的 torch.nn.Module 的原始参数或缓冲区。当使用 register_parametrization() 对模块中的张量进行参数化时,这个容器将作为 module.parametrizations[tensor_name] 的类型存在。
主要功能和特点
- 保存和管理参数: ParametrizationList保存了原始的参数或缓冲区,这些参数或缓冲区通过参数化被修改。
- 支持多重参数化: 如果首次注册的参数化有一个返回多个张量的 right_inverse方法,这些张量将以original0,original1, … 等的形式被保存。
参数
- modules(sequence): 代表参数化的模块序列。
- original(Parameter or Tensor): 被参数化的参数或缓冲区。
- unsafe(bool): 表明参数化是否可能改变张量的数据类型和形状。默认为False。当- unsafe=True时,不会在注册时检查参数化的一致性,使用时需要小心。
方法
- right_inverse(value): 按照注册的相反顺序调用参数化的- right_inverse方法。然后,如果- right_inverse输出一个张量,就将结果存储在- self.original中;如果输出多个张量,就存储在- self.original0,- self.original1, … 中。
注意事项
- 这个类主要由 register_parametrization()内部使用,并不建议用户直接实例化。
- unsafe参数的使用需要谨慎,因为它可能带来一致性问题。
示例
由于 ParametrizationList 主要用于内部实现,因此一般不会直接在用户代码中创建实例。它在进行参数化操作时自动形成,例如:
import torch.nn as nn
import torch.nn.utils.parametrize as P# 定义一个简单的模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = nn.Linear(10, 10)model = MyModel()# 对模型的某个参数应用参数化
P.register_parametrization(model.linear, "weight", MyParametrization())# ParametrizationList 实例可以通过以下方式访问
param_list = model.linear.parametrizations.weight
 在这个示例中,param_list 将是 ParametrizationList 类的一个实例,包含了 weight 参数的所有参数化信息。
总结
本篇博客探讨了 PyTorch 中 torch.nn.utils.parametrize 子模块的强大功能和灵活性。它详细介绍了如何通过自定义参数化(register_parametrization)来改变和控制模型参数的行为,提供了移除参数化(remove_parametrizations)的方法以恢复模型到原始状态,并探讨了如何利用缓存机制(cached)来提高参数化参数在前向传播中的计算效率。此外,文章还解释了如何检查模型或其属性的参数化状态(is_parametrized),并深入了解了 ParametrizationList 类在内部如何管理参数化参数。
相关文章:
PyTorch 参数化深度解析:自定义、管理和优化模型参数
目录 torch.nn子模块parametrize parametrize.register_parametrization 主要特性和用途 使用场景 参数和关键字参数 注意事项 示例 parametrize.remove_parametrizations 功能和用途 参数 返回值 异常 使用示例 parametrize.cached 功能和用途 如何使用 示例…...
 
自承载 Self-Host ASP.NET Web API 1 (C#)
本教程介绍如何在控制台应用程序中托管 Web API。 ASP.NET Web API不需要 IIS。 可以在自己的主机进程中自托管 Web API。 创建控制台应用程序项目 启动 Visual Studio,然后从“开始”页中选择“新建项目”。 或者,从“ 文件 ”菜单中选择“ 新建 ”&a…...
Vue2-子传父和父传子的基本用法
在Vue 2中,可以使用props和$emit来实现子组件向父组件传值(子传父)和父组件向子组件传值(父传子)。 子传父(子组件向父组件传值)的基本用法如下: 在父组件中定义一个属性ÿ…...
 
使用numpy处理图片——镜像翻转和旋转
在《使用numpy处理图片——基础操作》一文中,我们介绍了如何使用numpy修改图片的透明度。本文我们将介绍镜像翻转和旋转。 镜像翻转 上下翻转 from PIL import Image import numpy as np img Image.open(example.png) data np.array(img)# axis0 is vertical, a…...
 
HTML5 article标签,<time>...</time>标签和pubdate属性的运用
1、<article>...</article>标签的运用 article标签代表文档、页面或应用程序中独立的、完整的、可以独自被外部引用的内容。它可以是一篇博客或报竟杂志中的文章、一篇论坛帖子、一段用户评论或一个独立的插件,或者其他任何独立的内容。把文章正文放在h…...
 
Amazing OpenAI API:把非 OpenAI 模型都按 OpenAI API 调用
分享一个有趣的小工具,10MB 身材的小工具,能够将各种不同的模型 API 转换为开箱即用的 OpenAI API 格式。 让许多依赖 OpenAI API 的软件能够借助开发者能够接触到的,非 OpenAI 的 API 私有部署和使用起来。 写在前面 这个小工具软件写于两…...
 
RK3568平台开发系列讲解(驱动篇)pinctrl 函数操作集结构体讲解
🚀返回专栏总目录 文章目录 一、pinctrl_ops二、pinmux_ops三、pinconf_ops沉淀、分享、成长,让自己和他人都能有所收获!😄 pinctrl_ops:提供有关属于引脚组的引脚的信息。pinmux_ops:选择连接到该引脚的功能。pinconf_ops:设置引脚属性(上拉,下拉,开漏,强度等)。…...
vue购物车案例,v-model 之 lazy、number、trim,与后端交互
购物车案例 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><script src"./js/vue.js"></script> </head> <body> <div id"d1"&…...
 
云原生Kubernetes: Kubeadm部署K8S 1.29版本 单Master架构
目录 一、实验 1.环境 2.K8S master节点环境准备 3.K8S master节点安装kubelet、kubeadm、kubectl 3.K8S node节点环境准备与软件安装 4.K8S master节点部署服务 5.K8S node节点部署 6.K8S master节点查看集群 7.容器网络(CNI)部署 8.K8S 集群…...
 
C++协程操作
什么是C++协程 C++中的协程是一种用户态轻量级线程,它拥有自己的上下文和栈,并且协程的切换和调度由用户定义,不需要陷入内核。如同一个进程可以拥有多个线程,一个线程也可以拥有多个协程。协程的优点在于极高的执行效率,因为协程切换不需要陷入内核,而是由用户程序定义切…...
 
计算机配件杂谈-鼠标
目录 基础知识鼠标的发展鼠标的左右手鼠标的显示样式鼠标的移动和可见性移动可见性 现在的我们的生活工作都基本上离不开电脑了,不管是你平时玩玩游戏,上班工作等等; 今天将关于鼠标的一些小的技巧分享出来,共勉! 基础…...
 
用Python来制作一个微信聊天机器人
1. 效果展示 通过本地搭建一个flask服务器来接收信息,这里我简单使用展示,就没有对接收的信息进行处理了。 信息接收展示 发送信息展示 这里就直接使用python发送一个post请求即可,可以发送文字或者图片 代码展示 接收信息 #!/usr/bin/e…...
 
2024年第九届机器学习技术国际会议(ICMLT 2024) 即将召开
2024年第九届机器学习技术国际会议(ICMLT 2024)将于2024年5月24-26日在挪威奥斯陆举行。ICMLT 2024旨在讨论机器学习技术领域的最新研究技术现状和前沿趋势,为来自世界各地的科学家、工程师、实业家、学者和其他专业人士提供一个互动和交流的…...
算法训练day9Leetcode232用栈实现队列225用队列实现栈
今天学习的文章和视频链接 https://programmercarl.com/%E6%A0%88%E4%B8%8E%E9%98%9F%E5%88%97%E7%90%86%E8%AE%BA%E5%9F%BA%E7%A1%80.html 栈与队列理论基础 见我的博客 https://blog.csdn.net/qq_36372352/article/details/135470438?spm1001.2014.3001.5501 232用栈实现…...
 
linux驱动(四):platform
本文主要探讨x210驱动的平台设备类型(platform)以及misc设备。 驱动模型 设备驱动模型:总线(bus type)、设备(device)和驱动(driver) 总线:虚拟总线用于挂接驱动驱动和设备 总线、设备、驱动关系:/sys/bus下的子目录…...
 
Guava:Cache强大的本地缓存框架
Guava Cache是一款非常优秀的本地缓存框架。 一、 经典配置 Guava Cache 的数据结构跟 JDK1.7 的 ConcurrentHashMap 类似,提供了基于时间、容量、引用三种回收策略,以及自动加载、访问统计等功能。 基本的配置 Testpublic void testLoadingCache() th…...
#{}和${}的区别?
#{}是占位符,预编译处理;${}是拼接符,字符串替换,没有预编译处理。Mybatis在处理#{}时,#{}传入参数是以字符串传入,会将SQL中的#{}替换为?号,调用PreparedStatement的set方法来赋值。Mybatis在…...
 
string的模拟实现
string的模拟实现 msvc和g下的string内存比较成员变量构造函数与析构函数拷贝构造函数赋值拷贝c_str、size和capacity函数以及重载[]、clear、expand_capacity迭代器与遍历reservepush_back、append、insert字符串比较运算符erase<<流提取 >>流插入resizefindsubst…...
算法练习:查找二维数组中的目标值
题目: 编写一个高效的算法来搜索矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性:每行的元素从左到右升序排列。每列的元素从上到下升序排列。 实现: 1. main方法 public static void main(String[] args) {int[][] matrix {{1…...
考研自命题资料、考题如何找
这篇文章是抖音和b站上上传的同名视频的原文稿件,感兴趣的csdn用户可以关注我的抖音和b站账号(GeekPower极客力量)。同时这篇文章也为视频观众提供方便,可以更加冷静地分析和思考。文章同时在知乎发表。 去年我发布了一个视频&am…...
利用ngx_stream_return_module构建简易 TCP/UDP 响应网关
一、模块概述 ngx_stream_return_module 提供了一个极简的指令: return <value>;在收到客户端连接后,立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量(如 $time_iso8601、$remote_addr 等)&a…...
【C语言练习】080. 使用C语言实现简单的数据库操作
080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...
MySQL用户和授权
开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务: test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...
 
如何理解 IP 数据报中的 TTL?
目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...
 
用机器学习破解新能源领域的“弃风”难题
音乐发烧友深有体会,玩音乐的本质就是玩电网。火电声音偏暖,水电偏冷,风电偏空旷。至于太阳能发的电,则略显朦胧和单薄。 不知你是否有感觉,近两年家里的音响声音越来越冷,听起来越来越单薄? —…...
基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解
JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用,结合SQLite数据库实现联系人管理功能,并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能,同时可以最小化到系统…...
MySQL 主从同步异常处理
阅读原文:https://www.xiaozaoshu.top/articles/mysql-m-s-update-pk MySQL 做双主,遇到的这个错误: Could not execute Update_rows event on table ... Error_code: 1032是 MySQL 主从复制时的经典错误之一,通常表示ÿ…...
 
协议转换利器,profinet转ethercat网关的两大派系,各有千秋
随着工业以太网的发展,其高效、便捷、协议开放、易于冗余等诸多优点,被越来越多的工业现场所采用。西门子SIMATIC S7-1200/1500系列PLC集成有Profinet接口,具有实时性、开放性,使用TCP/IP和IT标准,符合基于工业以太网的…...
 
水泥厂自动化升级利器:Devicenet转Modbus rtu协议转换网关
在水泥厂的生产流程中,工业自动化网关起着至关重要的作用,尤其是JH-DVN-RTU疆鸿智能Devicenet转Modbus rtu协议转换网关,为水泥厂实现高效生产与精准控制提供了有力支持。 水泥厂设备众多,其中不少设备采用Devicenet协议。Devicen…...
在RK3588上搭建ROS1环境:创建节点与数据可视化实战指南
在RK3588上搭建ROS1环境:创建节点与数据可视化实战指南 背景介绍完整操作步骤1. 创建Docker容器环境2. 验证GUI显示功能3. 安装ROS Noetic4. 配置环境变量5. 创建ROS节点(小球运动模拟)6. 配置RVIZ默认视图7. 创建启动脚本8. 运行可视化系统效果展示与交互技术解析ROS节点通…...
