TensorRT量化工具pytorch_quantization代码解析(二)
有些地方看的不是透彻,后续继续补充!
继续看张量量化函数,代码位于:tools\pytorch-quantization\pytorch_quantization\tensor_quant.py
ScaledQuantDescriptor
量化的支持描述符:描述张量应该如何量化。QuantDescriptor和张量定义了量化张量。
class ScaledQuantDescriptor():def __init__(self, num_bits=8, name=None, **kwargs):if not isinstance(num_bits, int):raise TypeError("num_bits must be an integer, not {}.".format(type(num_bits)))if num_bits < 0:raise ValueError("num_bits must be >= 0, not {}.".format(num_bits))if num_bits == 0:logging.error("num_bits is 0. This will result in the tensor being quantized to all zeros."" This mode should only be used for debugging purposes.")self._num_bits = num_bitsif not isinstance(name, str) and name is not None:raise TypeError("name must be a string or None, not {}.".format(type(name)))self._name = nameself._fake_quant = kwargs.pop('fake_quant', True)self._axis = kwargs.pop('axis', None)if self._axis is not None:logging.debug("Meaning of axis has changed since v2.0. Make sure to update.")self._learn_amax = kwargs.pop('learn_amax', False)if self._learn_amax and self._axis is not None:raise TypeError("axis is ignored and must be None when learn_amax is true, got {}.".format(type(self._axis)))amax = kwargs.pop('amax', None)if amax is not None:if not isinstance(amax, float) and not isinstance(amax, list) and not isinstance(amax, np.ndarray):raise TypeError("amax must be float, list or ndarray, not {}".format(type(amax)))# Make it single precision arrayself._amax = np.array(amax, dtype=np.float32)else:self._amax = amaxself._scale_amax = kwargs.pop('scale_amax', None)self._calib_method = kwargs.pop('calib_method', "max")self._unsigned = kwargs.pop('unsigned', False)self._narrow_range = kwargs.pop('narrow_range', False)if kwargs:raise TypeError("Unused keys: {}".format(kwargs.keys()))
参数:
- num_bits:int,量化位数,用于计算比例因子。默认值8。
- name:看起来很不错
关键字参数:
-
fake_quant
:布尔值。如果为True,则使用fake
量化模式。默认为True -
axis
:None, int
或整数的tuple
,轴将利用自己的最大值以计算缩放因子,默认None。- 如果None(默认值),则使用
per tensor scale
。
确保在范围[-rank(input_tensor),rank(输入_tensor))内。
例如,对于KCRS权重张量,quant_axis=(0)
将产生per channel scaling
。
- 如果None(默认值),则使用
-
amax
:用户指定的绝对最大范围的float
或list/ndarray
。如果提供,忽略quant_axis
并使用它进行量化。如果learn_amax
为True,将用于初始化可学习的amax
。默认None
-
learn_amax
:boolean
,如果为True,学习amax。默认为False。 -
scale_amax
:float
,如果提供,将amax
乘以scale_amax
,默认无。 -
calib_method
:string
,[“max”,“histogram”]
中的一个校准要使用的指标。除了
max calibration
,其他都是基于hisogram
的。默认值“max”。 -
unsigned
:Boolean
,如果为True
,则使用无符号。默认为False
。
Raises:
- TypeError:如果传入了不支持的类型。
Read-only properties:
fake_quant:
name:
learn_amax:
scale_amax:
axis:
calib_method:
num_bits:
amax:
unsigned:
QuantDescriptor定义了张量应该如何量化。预定义的QuantDescriptor张量描述符如下:
QuantDescriptor = ScaledQuantDescriptor# Predefined descriptors
QUANT_DESC_8BIT_PER_TENSOR = QuantDescriptor(num_bits=8)
QUANT_DESC_UNSIGNED_8BIT_PER_TENSOR = QuantDescriptor(num_bits=8, unsigned=True)
QUANT_DESC_8BIT_CONV1D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONV3D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONVTRANSPOSE1D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONVTRANSPOSE2D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
QUANT_DESC_8BIT_CONVTRANSPOSE3D_WEIGHT_PER_CHANNEL = QuantDescriptor(num_bits=8, axis=(0))
如果在QuantDescriptor
中给出最amax
,TensorQuantizer
将使用它进行量化。否则,TensorQuantizer
将计算amax
,然后进行量化。amax
被计算通过指定的axis
轴。注意QuantDescriptor
将剩余轴指定与max()
轴相反。
例子:
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizerquant_desc = QuantDescriptor(num_bits=4, fake_quant=False, axis=(0), unsigned=True)
接下来看量化函数:pytorch_quantization
提供3个自定义的张量量化函数算子,继承torch.autograd.function
,实现函数的前向传播、反向传播
TensorQuantFunction
- 通用的张量量化函数
TensorQuantFunction
class TensorQuantFunction(Function):"""一个输入张量,输出一个量化张量。`scale`的粒度可以从amax的形状来解释"""
forward
在前向过程中,对浮点权重和激活进行伪量化,并使用这些伪量化的权重和激活来执行层的操作
@staticmethoddef forward(ctx, inputs, amax, num_bits=8, unsigned=False, narrow_range=True):ctx.save_for_backward(inputs, amax)outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)# Check if scale overflows FP16if outputs.dtype == torch.half and scale.max() > 65504:raise ValueError("scale is too large for FP16 with amax={}".format(amax))return outputs, scale.to(inputs.dtype)
output_dtype
指示量化值是以整数还是浮点形式存储。希望将其存储在浮点中的原因是pytorch函数接受量化值,它可能不接受整数输入,例如Conv2D
。
它使用2num_bits−12^{num\_bits-1}2num_bits−1值,例如,对于num_bits=8
,使用[-127,127]
遵循tensorflow约定,传入最大值并用于确定比例,而不是直接输入比例。尽管直接输入比例可能更自然。
参数:
-
ctx
:一个用于向后存储张量的Context对象。 -
inputs
:float32型张量。 -
amax
:float32型张量。输入将在[-amax,amax]范围内量化,amax将广播到inputs tensor。 -
num_bits
:用于计算缩放因子的整数,scale=(2num_bits−1−1)/maxscale=(2^{num\_bits-1}-1)/maxscale=(2num_bits−1−1)/max。默认值8。 -
output_dtype
:张量的一种类型。torch.int32或torch.float32。希望存储为float,pytorch函数接受float量化值,它可能不接受整数输入。
unsigned:boolean,使用无符号整数范围。例如,对于num_bits=8,[0,255]。默认为False。 -
narrow_range
:布尔值。使用对称整数范围进行有符号量化
例如,对于num_bits=8,用[-127,127]代替[-128,127]。默认为True。
Returns
:
-
outputs
:output_dtype
类型的张量。 -
scale
:float32
型张量。outputs / scale
将对输出张量进行反量化。
Raises
:
ValueError
:
backward
通过clipping
实现直通估计。对于-amax<=input<=amax
,梯度直接通过,否则梯度为零。
参数:
ctx
:一个上下文对象,其中保存了来自forward的张量。grad_outputs
:outputs梯度张量。grad_scale
:scale梯度张量。
Returns:
grad_inputs
:梯度张量。
@staticmethoddef backward(ctx, grad_outputs, grad_scale):"""Implements straight through estimation with clipping. For -amax <= input <= amaxthe gradient passes straight through, otherwise the gradient is zero.Args:ctx: A Context object with saved tensors from forward.grad_outputs: A tensor of gradient of outputs.grad_scale: A tensor of gradient of scale.Returns:grad_inputs: A tensor of gradient."""inputs, amax = ctx.saved_tensorszero = grad_outputs.new_zeros(1) # create a zero tensor with the same type and devicegrad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero)return grad_inputs, None, None, None, None
tensor_quant = TensorQuantFunction.apply
给TensorQuantFunction.apply
赋予一个别名tensor_quant
,这样可以直接调用tensor_quant
进行量化,例如:
from pytorch_quantization import tensor_quant# Generate random input. With fixed seed 12345, x should be
# tensor([0.9817, 0.8796, 0.9921, 0.4611, 0.0832, 0.1784, 0.3674, 0.5676, 0.3376, 0.2119])
torch.manual_seed(12345)
x = torch.rand(10)# quantize tensor x. quant_x will be
# tensor([126., 113., 127., 59., 11., 23., 47., 73., 43., 27.])
# with scale=128.0057
quant_x, scale = tensor_quant.tensor_quant(x, x.abs().max())
FakeTensorQuantFunction
class FakeTensorQuantFunction(Function):"""Fake version of TensorQuantFunctionSee comments of TensorQuantFunction, arguments are the same."""@staticmethoddef forward(ctx, inputs, amax, num_bits=8, unsigned=False, narrow_range=True):ctx.save_for_backward(inputs, amax)outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)return outputs / scale.to(inputs.dtype)@staticmethoddef backward(ctx, grad_outputs):inputs, amax = ctx.saved_tensorszero = grad_outputs.new_zeros(1)grad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero)return grad_inputs, None, None, None, None
在向后过程中,使用权重的渐变来更新浮点权重。为了处理量化梯度,除了未定义的点之外,几乎所有地方都是零,可以使用 直通估计器 ( STE ),它通过伪量化操作符传递梯度。
fake_tensor_quant = FakeTensorQuantFunction.apply
给TensorQuantFunction.apply
赋予一个别名fake_tensor_quant
,这样可以直接调用fake_tensor_quant
进行量化,例如:
from pytorch_quantization import tensor_quant# Generate random input. With fixed seed 12345, x should be
# tensor([0.9817, 0.8796, 0.9921, 0.4611, 0.0832, 0.1784, 0.3674, 0.5676, 0.3376, 0.2119])
torch.manual_seed(12345)
x = torch.rand(10)# fake quantize tensor x. fake_quant_x will be
# tensor([0.9843, 0.8828, 0.9921, 0.4609, 0.0859, 0.1797, 0.3672, 0.5703, 0.3359, 0.2109])
fake_quant_x = tensor_quant.fake_tensor_quant(x, x.abs().max())
_tensor_quant
def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):"""Shared function body between TensorQuantFunction and FakeTensorQuantFunction"""# Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning.if isinstance(amax, torch.Tensor) and inputs.dim() != amax.dim():logging.debug("amax %s has different shape than inputs %s. Make sure broadcast works as expected!",amax.size(), inputs.size())logging.debug("{} bits quantization on shape {} tensor.".format(num_bits, inputs.size()))if unsigned:if inputs.min() < 0.:raise TypeError("Negative values encountered in unsigned quantization.")# Computation must be in FP32 to prevent potential over flow.input_dtype = inputs.dtypeif inputs.dtype == torch.half:inputs = inputs.float()if amax.dtype == torch.half:amax = amax.float()min_amax = amax.min()if min_amax < 0:raise ValueError("Negative values in amax")max_bound = torch.tensor((2.0**(num_bits - 1 + int(unsigned))) - 1.0, device=amax.device)if unsigned:min_bound = 0elif narrow_range:min_bound = -max_boundelse:min_bound = -max_bound - 1scale = max_bound / amaxepsilon = 1. / (1<<24)if min_amax <= epsilon: # Treat amax smaller than minimum representable of fp16 0zero_amax_mask = (amax <= epsilon)scale[zero_amax_mask] = 0 # Value quantized with amax=0 should all be 0outputs = torch.clamp((inputs * scale).round_(), min_bound, max_bound)if min_amax <= epsilon:scale[zero_amax_mask] = 1. # Return 1 makes more sense for values quantized to 0 with amax=0if input_dtype == torch.half:outputs = outputs.half()return outputs, scale
待梳理!!!
相关文章:
TensorRT量化工具pytorch_quantization代码解析(二)
有些地方看的不是透彻,后续继续补充! 继续看张量量化函数,代码位于:tools\pytorch-quantization\pytorch_quantization\tensor_quant.py ScaledQuantDescriptor 量化的支持描述符:描述张量应该如何量化。QuantDescriptor和张量…...
buu [BJDCTF2020]easyrsa 1
题目描述 : from Crypto.Util.number import getPrime,bytes_to_long from sympy import Derivative from fractions import Fraction from secret import flagpgetPrime(1024) qgetPrime(1024) e65537 np*q zFraction(1,Derivative(arctan(p),p))-Fraction(1,Deri…...
taobao.user.openuid.getbyorder( 根据订单获取买家openuid )
¥免费不需用户授权 根据订单获取买家openuid,最大查询30个 公共参数 请求地址: HTTP地址 http://gw.api.taobao.com/router/rest 公共请求参数: 请求示例 TaobaoClient client new DefaultTaobaoClient(url, appkey, secret); UserOpenuidGetbyorderR…...
Mac iTerm2 rz sz
1、安装brew(找了很多🔗,就这个博主的好用) Mac如何安装brew?_行走的码农00的博客-CSDN博客_mac brew 2、安装lrzsz brew install lrzsz 检查是否安装成功 brew list 定位lrzsz的安装目录 brew list lrzsz 执…...
高通平台开发系列讲解(Sensor篇)Gsensor基础知识
文章目录 一、什么是SENSOR?二、Sensor的分类及作用三、Gsensor的工作原理及介绍3.1、常见Gsensor3.2、Gsensor的特性沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇文章将介绍 Sensor 基础 一、什么是SENSOR? 传感器(英文名称:sensor )是一种检测装置,能感…...
图像处理实战--Opencv实现人像迁移
前言: Hello大家好,我是Dream。 今天来学习一下如何使用Opencv实现人像迁移,欢迎大家一起参与探讨交流~ 本文目录:一、实验要求二、实验环境三、实验原理及操作1.照片准备2.图像增强3.实现美颜功能4.背景虚化5.图像二值化处理6.人…...
OnlyOffice验证(二)在Centos7上部署OnlyOffice编译结果
在Centos7上部署OnlyOffice编译结果 此处将尝试将OnlyOffice验证(一)DocumentServer编译验证的结果部署到Centos7上。并且使用其它服务器现有的RabbitMq和Mysql。 安装Nginx 先安装Nginx需要的依赖环境: yum install openssl* -y yum insta…...
6.补充和总结【Java面试第三季】
6.补充和总结【Java面试第三季】前言推荐6.补充和总结69_总结闲聊回顾和总结继续学习最后前言 2023-2-4 19:08:01 以下内容源自 【尚硅谷Java大厂面试题第3季,跳槽必刷题目必扫技术盲点(周阳主讲)-哔哩哔哩】 仅供学习交流使用 推荐 Jav…...
基于ssm框架大学生社团管理系统(源码+数据库+文档)
一、项目简介 本项目是一套基于ssm框架大学生社团管理系统,主要针对计算机相关专业的正在做bishe的学生和需要项目实战练习的Java学习者。 包含:项目源码、数据库脚本等,该项目可以直接作为bishe使用。 项目都经过严格调试,确保可…...
vulnhub靶场NAPPING: 1.0.1教程
靶场搭建靶机下载地址:Napping: 1.0.1 ~ VulnHub直接解压双击ova文件即可使用软件:靶机VirtualBox,攻击机VMware攻击机:kali信息收集arp-scan -l上帝之眼直接来看看网站可以注册账号,那就先试试。注册完后登入哦。要输…...
Docker基本介绍
最近需要将项目做成一个web应用并部署到多台服务器上,于是就简单学习了一下docker,做一下小小的记录。 1、简单介绍一下docker 我们经常遇到这样一个问题,自己写的代码在自己的电脑上运行的很流畅,在其他人电脑上就各种bug&…...
可用于标记蛋白质216699-36-4,6-ROX,SE,6-羧基-X-罗丹明琥珀酰亚胺酯
一.6-ROX,SE产品描述:6-羧基-X-罗丹明琥珀酰亚胺酯(6-ROX,SE)是一种用于寡核苷酸标记和自动DNA测序的荧光染料,可用于标记蛋白质,寡核苷酸和其他含胺分子的伯胺(-NH2)。西…...
高数:极限的定义
目录 极限的定义: 数列极限的几何意义: 由极限的定义得出的极限的两个结论: 编辑 极限的第三个结论: 例题 方法1: 编辑 方法2: 编辑 方法3: 编辑 极限的定义: 如何理…...
大数据技术之Hadoop
第1章 Hadoop概述1.1 Hadoop是什么1.2 Hadoop发展历史(了解)1.3 Hadoop三大发行版本(了解)Hadoop三大发行版本:Apache、Cloudera、Hortonworks。Apache版本最原始(最基础)的版本,对于…...
一文带你搞懂Go语言函数选项模式,Go函数一等公民。
前言 通过这篇文章《为什么说Go的函数是”一等公民“》,我们了解到了什么是“一等公民”,以及都具备哪些特性,同时对函数的基本使用也更加深入。 本文重点介绍下Go设计模式之函数选项模式,它得益于Go的函数是“一等公民”&#…...
Window.location 详细介绍
如果你需要获取网站的 URL 信息,那么 window.location 对象就是为你准备的。使用它提供的属性来获取当前页面地址的信息,或使用其方法进行某些页面的重定向或刷新。 https://www.samanthaming.com/tidbits/?filterJS#2 window.location.origin → htt…...
js侧滑显示删除按钮
效果图: <!DOCTYPE html> <html><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0, maximum-scale1.0, user-scalableno"><title>js侧滑显示删…...
Python - DIY - 使用dump取json某些键值对合成新的json文件
Python - Json处理前言:应用场景:基本工具:文件操作:打开文件:写文件:读文件:关闭文件并刷新缓冲区:Json字符串和字典转换:json.loads():json.dumps():Json文…...
深度剖析指针(中)——“C”
各位CSDN的uu们你们好呀,今天小雅兰的内容仍旧是深度剖析指针噢,在上一篇博客中,我已经写过了字符指针、数组指针、指针数组、数组传参和指针传参的知识点,那么这篇博客小雅兰会讲解一下函数指针、函数指针数组 、指向函数指针数组…...
论文阅读 | Video Frame Synthesis using Deep Voxel Flow
前言: 视频帧生成方法(视频插帧/视频预测)ICCV2017 oral Video Frame Synthesis using Deep Voxel Flow 引言 当下进行视频帧合成的方法分为两种,第一种是光流法,光流准确的话效果好,光流不准确的话则生…...
我所理解的生活
诞生 人真正意义上的诞生应该是社会学意义上的,是一种意识到自我、自我与社会关系的存在,只有这种诞生,才是完整人生的基点,大千世界中,唯有人类以生活作为自己的存在方式,除人类以外,从无机界…...
debian 部署nginx https
我是flask 处理请求单进程, 差点意思 , 考虑先flask 在往下走 一:安装nginx 因为我是debian 系统,所以我的建议是直接 sudo apt-get install nginx 你也可以选择在官网下载, 但是我搭建ssl 的时候安装openssl非常的麻…...
SQL 层功能改进 - lookupJoin 的优化
一、传统 join 算法lookupJoin 是 join 查询的一种,传统 join 算法为:1. 遍历 A 表,读取一条数据 r2. 遍历 B 表,对于每条数据,与 r 进行 join 操作3. 重复 1、2 操作,直到 A 表遍历完所有数据二、lookupJo…...
动态规划:鸣人的影分身
在火影忍者的世界里,令敌人捉摸不透是非常关键的。我们的主角漩涡鸣人所拥有的一个招数——多重影分身之术——就是一个很好的例子。影分身是由鸣人身体的查克拉能量制造的,使用的查克拉越多,制造出的影分身越强。针对不同的作战情况…...
如何为三星active2手表安装自己DIY的表盘
一、步骤介绍 Step 1. 下载Galaxy watch studio; Step 2. 按照up主“隔壁张师傅2022”的文章进行安装。 二、安装流程简单说明: ① 电脑端官网下载并安装Galaxy Watch Designer或者Galaxy Watch Studio程序。 ② 关闭手表蓝牙连接,并打开调…...
Android 项目必备(四十二)-->Android 多窗口模式
简介 自由窗口模式: 该模式类似于常见的桌面操作系统, 应用界面的窗口可以自由的拖动和修改大小。 分屏模式 该模式可以在手机上使用, 该模式将屏幕一分为二, 同时显示两个应用界面。 画中画模式: 该模式主要用于TV, 在该模式下…...
OpenHarmony的未来和如何做好一个开源社区
今天要分享的文章,可能更多只是作为一种观点。主要包括2个内容。OpenHarmony的未来和如何做好一个开源社区,好的,接下来开始今天的内容。 你对OpenHarmony的未来如何看待? OpenHarmony的未来看起来非常光明,因为它具…...
二叉搜索树实现
树的导览 树由节点(nodes)和边(edges)构成,如下图所示。整棵树有一个最上端节点,称为根节点(root)。每个节点可以拥有具有方向的边(directed edges)…...
解决Spring Data Jpa 实体类自动创建数据库表失败问题
先说一下我遇到的这个问题,首先我是通过maven创建了一个spring boot的工程,引入了Spring data jpa,结果实体类创建好之后,运行工程却没有在数据库中自动创建数据表。 找了半天发现是一个配置的问题! hibernate.ddl-auto节点的配…...
Elasticsearch:创建一个简单的 “你的意思是?” 推荐搜索
“你的意思是” 是搜索引擎中一个非常重要的功能,因为它们通过显示建议的术语来帮助用户,以便他可以进行更准确的搜索。比如,在百度中,我们进行搜索时,它通常会显示一些更为常用推荐的搜索选项来供我们选择:…...
在哪个网站可以搜画画做品/商业软文案例
准备工作: 1、安装并配置Java运行环境 2、数据库的安装配置(MySql) 3、安装并配置服务器(Tomcat) 4、Maven 5、Eclipse安装配置 6、使用Eclipse创建web app项目 JAR包集成:(pom.xml) <project xml…...
网站设计师工作室/怎么做好网站搜索引擎优化
俄式乘法,又被称为俄国农夫法,它是对两个正整数相乘的非主流算法。假设m和n是两个正整数,我们要计算它们的积。它的主要原理如下: if n is 偶数 n * mn/2 * 2m else n * m(n-1)/2 * 2m m 该算法只包括折半,加倍&#…...
深圳网站建设 site/盘古百晋广告营销是干嘛
一般编程语言中喜欢用符号来判断java中两个字符串是否相等,例如c。c提供了操作符的重载,所以可以重载运算符来判断。 但是由于java中,没有提供运算符重载,而且java中没有提供基本的string类型、也没有把string看成char数组&#x…...
东莞企业网站优化/如何免费做网站网页
转自:https://www.pinlue.com/article/2018/11/2808/327696834012.html...
wordpress转为pdf/建网站教学
在javascript中,当我们需要知道一个变量的基本数据类型的时候可以使用typeof,结果可以是"undefined、"boolean"、"number"、"string"、"function"和"object",而需要检测引用类型的时…...
泰安网站制作哪里有/国外搜索引擎大全不屏蔽
《使用boost::atomic实现读写锁》,作者:芊芊水,原文链接:http://www.cnblogs.com/qianqians/archive/2013/02/18/2915068.html 分享自:博客园Android客户端(http://android.walkingp.com/cnblogs/)...