PyTorch源码系列(一)——Optimizer源码详解
目录
- 1. Optimizer类
- 2. Optimizer概览
- 3. 源码解析
- 3.1 构造方法
- 3.1.1 全局设置情形
- 3.1.2 局部设置情形
- 3.1.3 覆盖测试
- 3.1.4 逐行讲解
- 3.2 add_param_group
- 3.3 step
- 3.4 zero_grad
- 3.5 self.state
- 3.6 state_dict
- 3.7 load_state_dict
- 4. SGD Optimizer
- 5. 极简版Optimizer源码
- 6. 自定义你的Optimizer
- Ref
1. Optimizer类
PyTorch的 Optimizer
类是深度学习模型中用于管理和更新模型参数的基类。它负责根据损失函数的梯度信息调整模型的参数,使模型逐步逼近最佳状态。Optimizer
类通过实现一些核心方法,如 step()
,来执行参数更新过程,而 zero_grad()
方法则用于清除模型中所有参数的梯度。
每个优化器会存储参数组和相关的状态,例如学习率、动量等。不同的优化器(如SGD、Adam等)继承自 Optimizer
类,并根据各自的算法特点实现了不同的参数更新策略。此外,Optimizer
类还允许用户在初始化时指定超参数,如学习率等,这些超参数会影响参数的更新方式。
本文将详细讲解 Optimizer
类的源码,并以SGD优化器为例介绍如何自定义一个自己的优化器。
2. Optimizer概览
📝 本文在讲解源码时,只考虑源码的简化版本,而不考虑完整的源码。
除了构造方法外,Optimizer常用的几个方法如下:
class Optimizer:def state_dict(self) -> Dict[str, Any]:"""返回优化器的状态字典,保存当前的优化器状态以便之后恢复。"""...def load_state_dict(self, state_dict: Dict[str, Any]) -> None:"""加载先前保存的状态字典,恢复优化器的状态。"""...def zero_grad(self, set_to_none: bool = True) -> None:"""将优化器中所有参数的梯度清零,通常在每次反向传播前调用。"""...def step(self) -> None:"""执行一步优化更新,用于根据梯度更新模型参数。"""raise NotImplementedErrordef add_param_group(self, param_group: Dict[str, Any]) -> None:"""向优化器中添加新的参数组,用于管理不同的参数组(如不同学习率等)。"""...
在自定义优化器时,必须继承 Optimizer
类,并实现 step()
方法,否则将会报错。
optimizer.py
文件中还定义了若干类型别名,如下:
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
Args
:通常用来表示函数中不定数量的位置参数。Kwargs
:通常用来表示函数中不定数量的关键字参数,键为参数名,值为相应的参数值。StateDict
:通常用于保存模型参数和优化器的状态信息。ParamsT
:是一个torch.Tensor
的可迭代对象,或者是包含键值对的字典的可迭代对象。
3. 源码解析
3.1 构造方法
Optimizer
的构造方法如下:
class Optimizer:r"""Base class for all optimizers... warning::Parameters need to be specified as collections that have a deterministicordering that is consistent between runs. Examples of objects that don'tsatisfy those properties are sets and iterators over values of dictionaries.Args:params (iterable): an iterable of :class:`torch.Tensor` s or:class:`dict` s. Specifies what Tensors should be optimized.defaults: (dict): a dict containing default values of optimizationoptions (used when a parameter group doesn't specify them)."""def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:self.defaults = defaultsif isinstance(params, torch.Tensor):raise TypeError("params argument given to the optimizer should be ""an iterable of Tensors or dicts, but got " + torch.typename(params))self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)self.param_groups: List[Dict[str, Any]] = []param_groups = list(params)if len(param_groups) == 0:raise ValueError("optimizer got an empty parameter list")if not isinstance(param_groups[0], dict):param_groups = [{"params": param_groups}]for param_group in param_groups:self.add_param_group(param_group)
从源码可以看出,Optimizer
有两个形参和两个属性(self.defaults
不算)。
形参:
params
:由一系列Tensor组成的迭代器或是由一系列字典组成的迭代器。通常是模型的参数。例如model.parameters()
。defaults
:一个键为字符串的字典。通常是和优化算法相关的全局超参数。例如lr
、momentum
等。下文会解释为什么这里是「全局」。
属性:
state
:一个键为Tensor,值为字典的字典。用来存储每个模型参数对应的临时状态,例如momentum
。param_groups
:一个列表,其中的每一个元素都是一个键为字符串的字典。每一个元素对应了一个param_group
。
看到这里,可能你仍然不明白 param_groups
是什么。既然它是复数形式,说明它是由一个个 param_group
组成的,每一个 param_group
的类型是 Dict[str, Any]
,因此 param_groups
的类型就是 List[Dict[str, Any]]
。
那么什么是 param_group
呢?我们知道Transformer模型通常由多个layer堆叠而成,绝大部分情况下,整个模型的训练会采用同一个学习率。但某些特殊场景下,我们可能希望不同的layer使用不同的学习率,此时就会涉及到一个个 param_group
了。
- 对于前者,形参
params
的类型为Iterable[Tensor]
,因为整个模型会共享同一套优化器参数,所以只需要指定全局优化器参数defaults
即可,此时param_groups
是一个长度为1的列表。 - 而对于后者,
params
的类型为Iterable[Dict[str, Any]]
,每一个字典包含了layer的参数和对应的局部优化器参数。注意此时依然可以指定全局优化器参数,例如我们希望不同的layer使用不同的学习率,但希望所有的layer都使用同一个动量。此时param_groups
是一个长度大于1的列表。
我们来看更具体的例子。
3.1.1 全局设置情形
class MLP(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(2, 3, bias=False)self.fc2 = nn.Linear(3, 3, bias=False)self.fc3 = nn.Linear(3, 1, bias=False)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return xmodel = MLP()# =============================
# 情况 1: 所有层使用同一套优化器参数
# =============================
# 此时,params 是 Iterable[Tensor],直接传递模型的所有参数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
print(optimizer.param_groups)
输出(做了简化处理):
[{'params': [fc1_tensor, fc2_tensor, fc3_tensor],'lr': 0.01,'momentum': 0.9,'dampening': 0,'weight_decay': 0,'nesterov': False,'maximize': False,'foreach': None,'differentiable': False,'fused': None}
]
可以看到此时 param_groups
是一个长度为1的列表。字典中的 params
对应了模型的参数,因为没有设置 bias
,所以总共有三个张量。字典中剩余的键对应了优化器的超参数。
3.1.2 局部设置情形
# ===============================
# 情况 2: 不同的层使用不同的优化器参数
# ===============================
# 此时,params 是 Iterable[Dict[str, Any]],可以为不同的层设置不同的学习率
optimizer = optim.SGD([{'params': model.fc1.parameters(), 'lr': 0.001}, # 第1层学习率为 0.001{'params': model.fc2.parameters(), 'lr': 0.01}, # 第2层学习率为 0.01{'params': model.fc3.parameters(), 'lr': 0.1} # 第3层学习率为 0.1
])print(optimizer.param_groups)
输出(做了简化处理):
[{'params': [fc1_tensor],'lr': 0.001,'momentum': 0,'dampening': 0,'weight_decay': 0,'nesterov': False,'maximize': False,'foreach': None,'differentiable': False,'fused': None},{'params': [fc2_tensor],'lr': 0.01,'momentum': 0,'dampening': 0,'weight_decay': 0,'nesterov': False,'maximize': False,'foreach': None,'differentiable': False,'fused': None},{'params': [fc3_tensor],'lr': 0.1,'momentum': 0,'dampening': 0,'weight_decay': 0,'nesterov': False,'maximize': False,'foreach': None,'differentiable': False,'fused': None}
]
此时 param_groups
是一个长度为3的列表。列表中的每一个字典就是一个 param_group
,存储了相应的layer参数和局部优化器参数。
⚠️ 这里使用“局部优化器参数”这一术语并不准确。事实上,“局部”和“全局”的概念仅在优化器实例化之前存在。一旦优化器实例化,全局参数将会逐步写入到每一个
param_group
中,例如momentum
和dampening
等未显式定义的参数,实质上属于全局优化器参数,并会通过add_param_group
方法自动地写入到每一个param_group
中。因此,在优化器实例化后,每个param_group
都拥有一套完整且独立的参数配置。
到目前为止,我们可以做一个简单总结。param_groups
是一个元素为字典的列表。当传入的 params
为由Tensor构成的迭代器时,此时 param_groups
的长度为1,即只含有一个 param_group
。当传入的 params
为由字典构成的迭代器时,此时 param_groups
的长度为 len(params)
。
param_groups
中的所有字典的键完全相同,均形如:
param_group = {'params': [tensor_1, tensor_2, ...], # 待优化的模型参数**defaults # 全局优化器参数
}
但所有字典的值却不尽相同。
3.1.3 覆盖测试
之前我们只考虑了「仅全局」和「仅局部」的情形,如果我们手动设置全局优化器参数,并且它和某些 param_group
中的局部优化器参数冲突了,那么这个全局的会覆盖掉局部的吗?
optimizer = optim.SGD([{'params': model.fc1.parameters(), 'lr': 0.001, 'momentum': 0.3},{'params': model.fc2.parameters(), 'lr': 0.01},{'params': model.fc3.parameters(), 'lr': 0.1, 'nesterov': True}
], momentum=0.9, nesterov=False)print(optimizer.param_groups)
输出:
[{'params': [fc1_tensor],'lr': 0.001,'momentum': 0.3,'dampening': 0,'weight_decay': 0,'nesterov': False,'maximize': False,'foreach': None,'differentiable': False,'fused': None},{'params': [fc2_tensor],'lr': 0.01,'momentum': 0.9,'dampening': 0,'weight_decay': 0,'nesterov': False,'maximize': False,'foreach': None,'differentiable': False,'fused': None},{'params': [fc3_tensor],'lr': 0.1,'nesterov': True,'momentum': 0.9,'dampening': 0,'weight_decay': 0,'maximize': False,'foreach': None,'differentiable': False,'fused': None}
]
由此可以得出结论:全局优化器参数不会覆盖掉局部优化器参数。
3.1.4 逐行讲解
现在我们已经对 params
、defaults
、param_groups
这三个变量有了足够的了解(state
会放在下文讲解),接下来我们逐行剖析构造方法。
def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:# 将defaults设置成属性以方便在后续的add_param_group方法中使用self.defaults = defaults# params必须是关于tensor或dict的可迭代对象if isinstance(params, torch.Tensor):raise TypeError("params argument given to the optimizer should be ""an iterable of Tensors or dicts, but got " + torch.typename(params))# 初始化两大重要属性:state和param_groupsself.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)self.param_groups: List[Dict[str, Any]] = []# 将迭代器转化成列表,此时要么是List[Tensor]要么是List[Dict]param_groups = list(params)# 必须非空if len(param_groups) == 0:raise ValueError("optimizer got an empty parameter list")# 如果不是List[Dict],说明进行的是全局设置,此时param_groups是List[Tensor]# 类型,代表模型的所有参数。然后将其转变成List[Dict]类型,以达到格式统一的目的if not isinstance(param_groups[0], dict):param_groups = [{"params": param_groups}]# 将每一个param_group经过相关处理后添加到self.param_groups中# 这里也会将全局优化器参数defaults注入到每一个param_group中for param_group in param_groups:self.add_param_group(param_group)
通常来讲,params
会接收 model.parameters()
作为输入,如果不是,那么 params
必须是由字典构成的列表,且每一个字典必须含有 params
这个键,对应的值是模型的部分参数。
绝大多数情况下,我们认为下式成立:
model.parameters() = ⋃ i = 1 k param_groups [ i ] [ " params " ] \text{model.parameters()}=\bigcup_{i=1}^k \text{param\_groups}[i]["\text{params}"] model.parameters()=i=1⋃kparam_groups[i]["params"]
k k k 是 param_group
的个数,且 param_groups[i]["params"]
两两互不相交。
如果涉及到冻结模型的一部分参数,仅训练剩余的参数,那么上式就不再成立了。
3.2 add_param_group
有了3.1小节的基础后,这里直接逐行讲解源码。
def add_param_group(self, param_group: Dict[str, Any]) -> None:r"""Add a param group to the :class:`Optimizer`'s `param_groups`.This can be useful when fine tuning a pre-trained network as frozen layers can be madetrainable and added to the :class:`Optimizer` as training progresses.Args:param_group (dict): Specifies what Tensors should be optimized along with groupspecific optimization options."""# 确保传入的param_group一定是一个字典if not isinstance(param_group, dict):raise TypeError(f"param_group must be a dict, but got {type(param_group)}")# 获取该group中的模型参数部分,然后将其转变为List[Tensor]类型params = param_group["params"]if isinstance(params, torch.Tensor):param_group["params"] = [params]elif isinstance(params, set):raise TypeError("optimizer parameters need to be organized in ordered collections, but ""the ordering of tensors in sets will change between runs. Please use a list instead.")else:param_group["params"] = list(params)# 参数检查for param in param_group["params"]:# 确保所有的参数必须都是Tensor,否则无法优化if not isinstance(param, torch.Tensor):raise TypeError("optimizer can only optimize Tensors, ""but one of the params is " + torch.typename(param))# 确保所有的参数都是叶子节点if not (param.is_leaf or param.retains_grad):raise ValueError("can't optimize a non-leaf Tensor")# 将全局优化器参数注入到当前的group中,setdefault保证了这一过程不会覆盖掉局部优化器参数for name, default in self.defaults.items():param_group.setdefault(name, default)# 检查是否存在重复的参数# 目前出现重复的参数并不会报错params = param_group["params"]if len(params) != len(set(params)):warnings.warn("optimizer contains a parameter group with duplicate parameters; ""in future, this will cause an error; ""see github.com/pytorch/pytorch/issues/40967 for more information",stacklevel=3,)# 判断当前的group中是否有参数和已经添加过的group中的参数重复# 两个集合交集为空,isdisjoint()返回Trueparam_set: Set[torch.Tensor] = set()for group in self.param_groups:param_set.update(set(group["params"]))if not param_set.isdisjoint(set(param_group["params"])):raise ValueError("some parameters appear in more than one parameter group")# 将当前的group添加到self.param_groups中self.param_groups.append(param_group)
add_param_group
源码看似复杂,但归根结底也就那么几行代码是真正起到作用的,这里给出一个简化版本:
def add_param_group(self, param_group: Dict[str, Any]) -> None:params = param_group["params"]param_group["params"] = [params] if isinstance(params, torch.Tensor) else list(params)for name, default in self.defaults.items():param_group.setdefault(name, default)self.param_groups.append(param_group)
3.3 step
Optimizer
类并没有实现 step()
方法:
def step(self) -> None:"""Performs a single optimization step (parameter update)."""raise NotImplementedError
自定义优化器时,需要继承 Optimizer
类,并实现该方法,否则会报错。
因为 step()
不返回任何值,所以需要实现模型参数的原地更新。
3.4 zero_grad
如下是简化版的源码:
def zero_grad(self, set_to_none: bool = True) -> None:r"""Resets the gradients of all optimized :class:`torch.Tensor` s.Args:set_to_none (bool): instead of setting to zero, set the grads to None.This will in general have lower memory footprint, and can modestly improve performance.However, it changes certain behaviors. For example:1. When the user tries to access a gradient and perform manual ops on it,a None attribute or a Tensor full of 0s will behave differently.2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ sare guaranteed to be None for params that did not receive a gradient.3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None(in one case it does the step with a gradient of 0 and in the other it skipsthe step altogether)."""# 遍历每一个group,然后遍历group中的每一个参数for group in self.param_groups:for p in group["params"]:# 如果p的梯度不为None,说明需要清空if p.grad is not None:if set_to_none:p.grad = Noneelse:# grad_fn 表示该张量是由某个操作生成的# 因此将该张量从计算图中分离出来if p.grad.grad_fn is not None:p.grad.detach_()else:# 不在计算图中,原地关闭梯度以防止后续追踪p.grad.requires_grad_(False)# 清零梯度p.grad.zero_()
从注释可以看出,当梯度被设置为 None
时,梯度张量会被释放,从而减少内存占用。而如果梯度被清零(即将其所有元素设置为 0),梯度张量的内存仍然会被保留。
由于 set_to_none
默认为 True
,因此 zero_grad
源码可以进一步简化:
def zero_grad(self) -> None:for group in self.param_groups:for p in group["params"]:if p.grad is not None:p.grad = None
3.5 self.state
Optimizer
中有两大重要属性:state
、param_groups
。先前我们已经了解了 param_groups
,现在来看 state
。
我们已经知道构造方法中会通过多次调用 add_param_group
来初始化 self.param_groups
,但截止目前,似乎并没有方法能够初始化 self.state
,那它是怎么初始化的?以及它到底“长什么样”呢?
self.state
用来存储与模型参数相关的临时状态。对于SGD with momentum而言,每个参数都需要维护一个动量。对于Adam而言,每个参数不仅需要维护一个动量(一阶矩),还需要维护一个平方梯度(二阶矩)。很显然,对于一个待优化的Tensor,它的临时状态和它的形状是相同的,并且对于该Tensor,可能有多个临时状态需要维护,每个临时状态都有一个自己的名字,由此推测 self.state
的类型应当是 Dict[Tensor, Dict[str, Tensor]]
,这与源码中声明的相同:
self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
即 self.state
的键是一个Tensor,值是一个字典。字典存储了优化该Tensor的一些临时状态,键是临时状态的名称,值是相应的状态。
以SGD优化器为例,进行一次单步更新,然后查看它的 state
属性:
import torch
import torch.nn as nn
import torch.optim as optimclass SimpleMLP(nn.Module):def __init__(self):super(SimpleMLP, self).__init__()self.fc1 = nn.Linear(2, 3, bias=False)self.fc2 = nn.Linear(3, 1, bias=False)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleMLP()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)inputs = torch.randn(2)
target = torch.randn(1)output = model(inputs)
loss = criterion(output, target)loss.backward()
optimizer.step()print(optimizer.state)
输出:
defaultdict(<class 'dict'>, {Parameter containing:tensor([[-0.4128, -0.6015],[-0.3628, 0.6162],[ 0.4367, 0.3409]], requires_grad=True): {'momentum_buffer': tensor([[ 0.0000, 0.0000],[-0.0052, 0.0581],[-0.0130, 0.1446]])},Parameter containing:tensor([[0.5394, 0.1077, 0.2748]], requires_grad=True): {'momentum_buffer': tensor([[0.0000, 0.3398, 0.1585]])}
})
可以得知,对于SGD优化器而言,每个待优化的Tensor只需要维护一个临时状态:momentum_buffer
,即当前的动量。且临时状态的形状与待优化的Tensor相同。
将SGD换成Adam,再来看看结果:
defaultdict(<class 'dict'>, {Parameter containing:tensor([[ 0.4954, -0.0392],[ 0.0778, -0.5769],[-0.3332, -0.0659]], requires_grad=True): {'step': tensor(1.),'exp_avg': tensor([[-0.1632, 0.0252],[-0.0513, 0.0079],[ 0.0000, 0.0000]]),'exp_avg_sq': tensor([[2.6648e-03, 6.3460e-05],[2.6274e-04, 6.2572e-06],[0.0000e+00, 0.0000e+00]])},Parameter containing:tensor([[-0.4492, -0.1479, 0.4789]], requires_grad=True): {'step': tensor(1.),'exp_avg': tensor([[0.1821, 0.0577, 0.0000]]),'exp_avg_sq': tensor([[0.0033, 0.0003, 0.0000]])}
})
此时每个Tensor需要维护三个临时状态:step
、exp_avg
和 exp_avg_sq
。step
是当前更新的步数,exp_avg
就是SGD中的动量(不完全相同),exp_avg_sq
是平方梯度。
知道了 self.state
长什么样后,我们需要了解一下它是如何初始化的。
事实上 Optimizer
类并没有实现 self.state
的初始化,因为它的初始化是在 step()
中完成的,所以我们需要关注 Optimizer
的子类,这里以SGD为例。
class SGD(Optimizer):def step(self):for group in self.param_groups:# 用来存储模型参数,梯度,动量params: List[Tensor] = []grads: List[Tensor] = []momentum_buffer_list: List[Optional[Tensor]] = []# 填充params、grads、momentum_buffer_listself._init_group(group, params, grads, momentum_buffer_list)# 执行sgd优化算法sgd(params, grads, momentum_buffer_list, ...)if group["momentum"] != 0:for p, momentum_buffer in zip(params, momentum_buffer_list):# 获取Tensorstate = self.state[p]# 更新Tensor的临时状态state["momentum_buffer"] = momentum_bufferdef _init_group(self, group, params, grads, momentum_buffer_list):for p in group["params"]:if p.grad is not None:params.append(p)grads.append(p.grad)if group["momentum"] != 0:# 因为self.state是defaultdict,所以初始时state会自动创建为一个字典state = self.state[p]# 因为初始时state没有momentum_buffer这个键# 所以momentum_buffer_list的初始值为[None, None, None, ...]momentum_buffer_list.append(state.get("momentum_buffer"))
每一步更新,_init_group
会在SGD算法执行前被调用,它用来获取模型所有待更新的参数,对应的已经计算的梯度,以及对应的上一时刻的动量。self.state
会在 _init_group
中进行初始化。我们可以通过在 self._init_group
和 sgd()
这两个语句之间加入以下代码来查看相应的信息:
print(self.state)
print(momentum_buffer_list)
输出:
defaultdict(<class 'dict'>, {Parameter containing:tensor([[ 0.4679, -0.6531],[-0.4707, -0.2854],[ 0.6846, -0.6576]], requires_grad=True): {},Parameter containing:tensor([[-0.4362, 0.4155, 0.1798]], requires_grad=True): {}
})[None, None]
这说明初始时,每一个Tensor对应的临时状态为空字典,且 momentum_buffer_list
的初始值为 [None, None, ...]
。
3.6 state_dict
state_dict
的作用是保存优化器当前的状态,以便在之后的训练中恢复或继续使用。
显然 state_dict
应当保存优化器的两大重要属性:
state_dict = {"state": state,"param_groups": param_groups,
}
但根据之前的分析,state
中的键会涉及到模型参数,此外,param_groups
中的 params
也会涉及到模型参数,如果就这样直接保存,相当于我们在保存了优化器的同时还保存了两份模型参数(state一份,param_groups一份),这显然是不可行的。
一种直观的想法是将这些模型参数映射成唯一的数字ID,而这些ID是几乎不占空间的。假设 param_groups
含有 k k k 个 param_group
,那么我们可以从第一个group开始,从0开始从前往后依次编号直至第 k k k 个group。
当然,我们还需要对 state
进行编号,由于 state
和 param_groups
中的参数不一定一一对应(因为 param_groups
可能会含有重复的参数,但 state
不会,所以二者长度不一定相等,具体见 add_param_group
源码),因此我们不能从0开始从前往后依次编号。这启发我们可以构造一个参数的内存地址到ID的映射 mapping
,这样在编号 state
的时候,我们就可以通过 mapping[id(tensor)]
来获取模型参数的ID了。
源码解析(已做简化):
def state_dict(self) -> Dict[str, Any]:# 构建模型参数地址到ID的映射param_mappings: Dict[int, int] = {}start_index = 0def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:nonlocal start_index# 处理优化器超参数部分packed = {k: v for k, v in group.items() if k != "params"}# 更新映射param_mappings.update({id(p): ifor i, p in enumerate(group["params"], start_index)if id(p) not in param_mappings})# 处理模型参数部分,将具体的参数映射为IDpacked["params"] = [param_mappings[id(p)] for p in group["params"]]start_index += len(packed["params"])return packed# 将所有group中的所有模型参数映射为IDparam_groups = [pack_group(g) for g in self.param_groups]# 将state中的所有模型参数映射为IDpacked_state = {param_mappings[id(k)]: vfor k, v in self.state.items()}state_dict = {"state": packed_state,"param_groups": param_groups,}return state_dict
我们可以通过以下代码来查看 state_dict
的样子:
import torch
import torch.nn as nn
import torch.optim as optim
import itertoolsclass MLP(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(2, 3, bias=False)self.fc2 = nn.Linear(3, 3, bias=False)self.fc3 = nn.Linear(3, 1, bias=False)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return xmodel = MLP()optimizer = optim.SGD([{"params": itertools.chain(model.fc1.parameters(), model.fc2.parameters()), "lr": 0.01},{"params": model.fc3.parameters(), "lr": 0.1},
], momentum=0.9)x = torch.randn(1, 2)
y = torch.randn(1, 1)criterion = nn.MSELoss()output = model(x)
loss = criterion(output, y)optimizer.zero_grad()
loss.backward()
optimizer.step()print(optimizer.state_dict().keys())print(optimizer.state_dict()['state'])
print(optimizer.state_dict()['param_groups'])
输出:
dict_keys(['state', 'param_groups'])# state部分
{0: {'momentum_buffer': tensor([[-0.0270, 0.1874],[ 0.0000, 0.0000],[ 0.0000, 0.0000]])},1: {'momentum_buffer': tensor([[ 0.0000, 0.0000, 0.0000],[-0.2232, 0.0000, 0.0000],[ 0.0000, 0.0000, 0.0000]])},2: {'momentum_buffer': tensor([[ 0.0000, -0.3215, 0.0000]])}
}# param_groups部分
[{'lr': 0.01,'momentum': 0.9,'dampening': 0,'weight_decay': 0,'nesterov': False,'maximize': False,'foreach': None,'differentiable': False,'fused': None,'params': [0, 1]},{'lr': 0.1,'momentum': 0.9,'dampening': 0,'weight_decay': 0,'nesterov': False,'maximize': False,'foreach': None,'differentiable': False,'fused': None,'params': [2]}
]
可以看到 state
和 param_groups
中的模型参数全被映射成了ID。
3.7 load_state_dict
因为之前的ID映射是根据 param_groups
构造的,所以在load的时候,我们也要根据 param_groups
去建立一一对应关系,此时需要构造ID到模型参数的映射。
这里有一个细节,在还原 params
的时候,我们可以直接通过ID进行还原,但是在还原 state
的时候,我们要确保各个临时状态和模型参数是处于同一设备上。
源码解析(已做简化):
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:# 获取当前优化器的param_groups和之前保存的param_groupsgroups = self.param_groupssaved_groups = deepcopy(state_dict["param_groups"])# 确保group的数量相等if len(groups) != len(saved_groups):raise ValueError("loaded state dict has a different number of " "parameter groups")# 确保每个group中的模型参数个数相等param_lens = (len(g["params"]) for g in groups)saved_lens = (len(g["params"]) for g in saved_groups)if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):raise ValueError("loaded state dict contains a parameter group ""that doesn't match the size of optimizer's group")# 之前是根据enumerate构造的正向映射# 现在直接用zip构造反向映射id_map = dict(zip(chain.from_iterable(g["params"] for g in saved_groups),chain.from_iterable(g["params"] for g in groups),))# 用于递归地将Tensor对应的临时状态移动到和Tensor相同的设备上,并确保数据类型相同def _cast(param, value, key=None):if isinstance(value, torch.Tensor):if key == 'step':return valueelse:if param.is_floating_point():return value.to(dtype=param.dtype, device=param.device)else:# 例如模型是一个量化模型,此时不必转化value的数据类型return value.to(device=param.device)elif isinstance(value, dict):return {k: _cast(param, v, key=k)for k, v in value.items()}elif isinstance(value, Iterable):return type(value)(_cast(param, v) for v in value)else:return value# 还原state# 注意要将临时状态转移到和param相同的设备上,有些时候还需要确保数据类型相同state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)for k, v in state_dict["state"].items():param = id_map[k]state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])# 还原param_groups,只有params需要修改def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:new_group["params"] = group["params"]return new_groupparam_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]# 更新self.__dict__.update({"state": state, "param_groups": param_groups})
4. SGD Optimizer
在了解 Optimizer
源码后,我们开看SGD优化器是如何实现的。
构造函数(简化版):
class SGD(Optimizer):def __init__(self,params,lr: float = 1e-3,momentum: float = 0,dampening: float = 0,weight_decay: float = 0,nesterov=False,):# 将独属于SGD优化器的超参数打包成defaultsdefaults = dict(lr=lr,momentum=momentum,dampening=dampening,weight_decay=weight_decay,nesterov=nesterov,)# 调用父类的构造函数初始化super().__init__(params, defaults)
无非就做了两件事:
- 将独属于该优化器的超参数打包成
defaults
- 调用父类的构造函数进行初始化(不然没有
state
和param_groups
这两个属性)
step
和 _init_group
在3.5节中已经介绍过,这里关注sgd函数如何实现。
⚠️ 对SGD算法不熟悉的读者可以看这篇博客:深入解析SGD、Momentum与Nesterov:优化算法的对比与应用述
def sgd(params: List[Tensor],grads: List[Tensor],momentum_buffer_list: List[Optional[Tensor]],weight_decay: float,momentum: float,lr: float,dampening: float,nesterov: bool,
):for i, param in enumerate(params):grad = grads[i]# 权重衰减if weight_decay != 0:grad = grad.add(param, alpha=weight_decay)if momentum != 0:buf = momentum_buffer_list[i] # 存储的是每个参数上一时刻的动量if buf is None:buf = torch.clone(grad).detach() # 初始时动量就是梯度,因为m_0 = 0momentum_buffer_list[i] = bufelse:# buf = momentum * buf + (1 - dampening) * gradbuf.mul_(momentum).add_(grad, alpha=1 - dampening)if nesterov:grad = grad.add(buf, alpha=momentum)else:grad = buf# 更新权重:param = param - lr * gradparam.add_(grad, alpha=-lr)
5. 极简版Optimizer源码
截止目前,我们可以对 Optimizer
基类的源码进行汇总,给出一个极简版的实现:
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import Any, Dict, Iterableimport torchclass Optimizer:def __init__(self, params, defaults: Dict[str, Any]) -> None:self.defaults = defaultsself.state = defaultdict(dict)self.param_groups = []param_groups = list(params)if not isinstance(param_groups[0], dict):param_groups = [{"params": param_groups}]for param_group in param_groups:self.add_param_group(param_group)def add_param_group(self, param_group: Dict[str, Any]) -> None:params = param_group["params"]param_group["params"] = [params] if isinstance(params, torch.Tensor) else list(params)for name, default in self.defaults.items():param_group.setdefault(name, default)self.param_groups.append(param_group)def step(self) -> None:raise NotImplementedErrordef zero_grad(self) -> None:for group in self.param_groups:for p in group["params"]:if p.grad is not None:p.grad = Nonedef state_dict(self) -> Dict[str, Any]:param_mappings = {}start_index = 0def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:nonlocal start_indexpacked = {k: v for k, v in group.items() if k != "params"}param_mappings.update({id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings})packed["params"] = [param_mappings[id(p)] for p in group["params"]]start_index += len(packed["params"])return packedparam_groups = [pack_group(g) for g in self.param_groups]packed_state = {param_mappings[id(k)]: v for k, v in self.state.items()}return {"state": packed_state, "param_groups": param_groups}def load_state_dict(self, state_dict: Dict[str, Any]) -> None:groups = self.param_groupssaved_groups = deepcopy(state_dict["param_groups"])id_map = dict(zip(chain.from_iterable(g["params"] for g in saved_groups),chain.from_iterable(g["params"] for g in groups),))def _cast(param, value, key=None):if isinstance(value, torch.Tensor):if key == 'step':return valueelse:return value.to(dtype=param.dtype, device=param.device) if param.is_floating_point() else value.to(device=param.device)elif isinstance(value, dict):return {k: _cast(param, v, key=k) for k, v in value.items()}elif isinstance(value, Iterable):return type(value)(_cast(param, v) for v in value)else:return valuestate = defaultdict(dict)for k, v in state_dict["state"].items():param = id_map[k]state[param] = _cast(param, v, param_id=k)def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:new_group["params"] = group["params"]return new_groupparam_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]self.__dict__.update({"state": state, "param_groups": param_groups})
6. 自定义你的Optimizer
以SGD为例,我们来实现一个只有学习率的极简版优化器。
步骤如下:
- 继承
Optimizer
,声明构造函数,构造函数的形参必须含有params
,随后的一系列形参都是和该优化器相关的超参数。 - 在构造函数中将相关超参数打包成
defaults
,然后调用父类的构造函数。 - 重写
step
方法,并用@torch.no_grad()
装饰。 - 在
step
中遍历self.param_groups
,每一次遍历,声明params
、grads
列表,然后用group
的数据进行填充。如果涉及到临时状态,还需要额外声明和临时状态相关的列表。
class SimpleSGD(Optimizer):def __init__(self, params, lr=0.01):assert lr > 0.0defaults = dict(lr=lr)super().__init__(params, defaults)@torch.no_grad()def step(self):for group in self.param_groups:params = []grads = []for p in group['params']:if p.grad is not None:params.append(p)grads.append(p.grad)lr = group['lr']for param, grad in zip(params, grads):param.add_(grad, alpha=-lr)
仅设置学习率时,它所产生的结果和官方的SGD相同,读者可自行验证。
Ref
[1] https://blog.csdn.net/zzxxxaa1/article/details/121144570?spm=1001.2014.3001.5502
[2] https://www.hjhgjghhg.com/archives/119/
[3] https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
相关文章:
PyTorch源码系列(一)——Optimizer源码详解
目录 1. Optimizer类2. Optimizer概览3. 源码解析3.1 构造方法3.1.1 全局设置情形3.1.2 局部设置情形3.1.3 覆盖测试3.1.4 逐行讲解 3.2 add_param_group3.3 step3.4 zero_grad3.5 self.state3.6 state_dict3.7 load_state_dict 4. SGD Optimizer5. 极简版Optimizer源码6. 自定…...
Java - LeetCode面试经典150题(三)
区间 228. 汇总区间 题目 给定一个 无重复元素 的 有序 整数数组 nums 。 返回 恰好覆盖数组中所有数字 的 最小有序 区间范围列表 。也就是说,nums 的每个元素都恰好被某个区间范围所覆盖,并且不存在属于某个范围但不属于 nums 的数字 x 。 列表中…...
基于SpringBoot+Vue+MySQL的民宿预订平台
系统展示 用户前台界面 管理员后台界面 商家后台界面 系统背景 随着旅游业的蓬勃发展,民宿作为一种独特的住宿方式,受到了越来越多游客的青睐。然而,传统的民宿预定方式往往存在信息不对称、效率低下等问题,难以满足游客的个性化需…...
Hadoop krb5.conf 配置详解
krb5.conf文件是Kerberos认证系统中的一个关键配置文件,它包含了Kerberos的配置信息,如KDC(Key Distribution Centers)和Kerberos相关域的管理员服务器位置、当前域和Kerberos应用的默认设置、以及主机名与Kerberos域的映射等。以…...
工程师 - DNS请求过程
DNS(Domain Name System,域名系统)是互联网的重要基础设施之一,其主要功能是将人们容易记忆的域名(例如 www.example.com)转换为计算机能识别的IP地址(例如 192.0.2.1),类…...
Solidity智能合约中的事件和日志
1. Solidity 中的事件和日志概述 1.1 什么是事件? 在 Solidity 中,事件(Event)是一种允许智能合约与外部世界进行通信的机制。通过触发事件,可以记录合约执行中的关键操作,并将这些操作发送到链上。事件的…...
第四十一篇-Docker安装Neo4j
创建目录 mkdir /opt/neo4j-data创建 docker run \ -d --name neo4j \ -p 7474:7474 -p 7687:7687 \ -v /opt/neo4j-data/data:/data \ -v /opt/neo4j-data/logs:/logs \ -v /opt/neo4j-data//conf:/var/lib/neo4j/conf \ -v /opt/neo4j-data/plugins:/plugins \ --env NEO4J…...
数电基础(组合逻辑电路+Proteus)
1.组合逻辑电路 1.1组合逻辑电路的分析 1.1.1组合逻辑电路的定义 组合逻辑电路的定义 (1)对于一个逻辑电路,其输出状态在任何时刻只取决于同一时刻的输入状态,而与电路的原来状态无关,这种电路被定义为组合逻辑电路…...
自给自足:手搓了一个睡眠监测仪,用着怎么样?
很久不分享手搓党作品拉! 今天分享一个“基于毫米波雷达的睡眠监测仪”作品! 用Air700E开发板毫米波雷达,手搓一个开箱即用的睡眠监测仪,不花冤枉钱! 来仔细瞧瞧! 一、项目原理及硬件制作 毫米波是指频率…...
Miniforge详细安装教程(macOs和Windows)
(注:主要是解决商业应用anaconda收费问题,这是轻量级的代替,个人完全可以使用anaconda和miniconda) Miniforge 是一个轻量级的包管理器,类似于 Anaconda 和 Miniconda。它主要用于安装基于 conda 的 Python 环境,专注于…...
HDFS Shell作业1
1.在HDFS上建立/user/stu/自己学号,和/user/stu/input目录。 命令: hdfs dfs -mkdir -p /user/stu/22 hdfs dfs -mkdir /user/stu/input 2.用两种不同的方法上传albums.csv至HDFS的学号目录和input目录中。 命令: hdfs dfs -put par…...
工业交换机一键重启的好处
在当今高度自动化和智能化的工业环境中,工业交换机作为网络系统中至关重要的一环,其稳定性和可靠性直接影响到整个生产过程的顺利进行。为了更好地维护这些设备的健康运行,一键重启功能应运而生,并呈现出诸多显著的好处。 首先&am…...
滚雪球学Oracle[4.2讲]:PL/SQL基础语法
全文目录: 前言一、PL/SQL基础语法1.1 变量声明变量声明示例: 二、记录类型与集合类型的使用2.1 记录类型记录类型的定义与使用 2.2 集合类型 三、PL/SQL表与关联数组3.1 PL/SQL表(嵌套表)嵌套表的定义与使用 3.2 关联数组关联数组…...
springboot系列--web相关知识探索二
一、映射 指的是与请求处理方法关联的URL路径,通过在Spring MVC的控制器类(使用RestController注解修饰的类)上使用注解(如 RequestMapping、GetMapping)来指定请求映射路径,可以将不同的HTTP请求映射到相应…...
Oracle 12c在Windows环境下安装
适合初学者使用的Oracle 12c在Windows环境下安装步骤、参数配置、常见问题及参数调优的详细补充说明。 一、Oracle 12c安装步骤 1. 准备工作 在安装Oracle 12c之前,确保你的系统满足以下要求: 操作系统:Oracle 12c支持的Windows版本包括Wi…...
Stable Diffusion绘画 | 来训练属于自己的模型:打标处理与优化
上一篇完成的打标工作,是为了获取提示词,让AI认识和学习图片的特征。 因此,合适、恰当、无误的提示词,对最终模型效果是相当重要的。 Tag 如何优化 通过软件自动生成的 Tag 只是起到快速建立大体架构的作用,里面会涉…...
【论文笔记】Visual Instruction Tuning
🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。 基本信息 标题: Visual Instruction Tunin…...
ubuntu 设置静态IP
一、 ip addresssudo nano /etc/netplan/50-cloud-init.yaml 修改前: 修改后: # This file is generated from information provided by the datasource. Changes # to it will not persist across an instance reboot. To disable cloud-inits # ne…...
Java 每日一刊(第19期):泛型
文章目录 前言1. 泛型概述1.1 不使用泛型 vs 使用泛型1.2 泛型的作用 2. 泛型的基本语法2.1 定义带类型参数的泛型类2.2 使用泛型类2.3 泛型方法 3. 泛型类型推断与钻石操作符3.1 类型推断3.2 钻石操作符 4. 通配符的使用4.1 无界通配符 <?>4.2 上界通配符 <? exten…...
windows下安装rabbitMQ并开通管理界面和允许远程访问
如题,在windows下安装一个rabbitMQ server;然后用浏览器访问其管理界面;由于rabbitMQ的默认账号guest默认只能本机访问,因此需要设置允许其他机器远程访问。这跟mysql的思路很像,默认只能本地访问,要远程访…...
深度剖析音频剪辑免费工具的特色与优势
是热爱生活的伙伴或者想要记录美好声音的普通用户,都可能会需要对音频进行剪辑处理。而幸运的是,现在有许多优秀的音频剪辑软件提供了免费版本,让我们能够轻松地施展音频剪辑的魔法。接下来,就让我们一同深入了解这些音频剪辑免费…...
Oracle中TRUNC()函数详解
文章目录 前言一、TRUNC函数的语法二、主要用途三、测试用例总结 前言 在Oracle中,TRUNC函数用于截取或截断日期、时间或数值表达式的部分。它返回一个日期、时间或数值的截断版本,根据提供的格式进行截取。 一、TRUNC函数的语法 TRUNC(date) TRUNC(d…...
【Spring Boot 入门一】构建你的第一个Spring Boot应用
一、引言 在当今的软件开发领域,Java一直占据着重要的地位。而Spring Boot作为Spring框架的延伸,为Java开发者提供了一种更加便捷、高效的开发方式。它简化了Spring应用的搭建和配置过程,让开发者能够专注于业务逻辑的实现。无论是构建小型的…...
PPT 快捷键使用、技巧
前言: 本文操作是以office 2021为基础的,仅供参考;不同版本office 的 ppt 快捷键 以及对应功能会有差异,需要实践出真知。 shift 移动 水平/垂直 移动 ; shift 放大/缩小 等比例放大 缩小 ; 正圆 正…...
Web安全 - 文件上传漏洞(File Upload Vulnerability)
文章目录 OWASP 2023 TOP 10导图定义攻击场景1. 上传恶意脚本2. 目录遍历3. 覆盖现有文件4. 文件上传结合社会工程攻击 防御措施1. 文件类型验证2. 文件名限制3. 文件存储位置4. 文件权限设置5. 文件内容检测6. 访问控制7. 服务器配置 文件类型验证实现Hutool的FileTypeUtil使用…...
vue3中el-input在form表单按下回车刷新页面
摘要: 在input框中点击回车之后不是调用我写的回车事件,而是刷新页面! 如果表单中只有一个input 框则按下回车会直接关闭表单 所以导致刷新页面 再写一个input 表单 ,并设置style“display:none” <ElInput style"display…...
SQL Server中关于个性化需求批量删除表的做法
在实际开发中,我们常常会遇到需要批量删除表,且具有共同特征的情况,例如:找出表名中数字结尾的表之类的,本文我将以3中类似情况为例,来示范并解说此类需求如何完成: 第一种,批量删除…...
关于按键状态机解决Delay给程序带来的问题
问题产生 我在学习中断的过程中,使用EXTI15外部中断,在其中加入HAL_Delay();就会发生报错 错误地方 其它地方配置 问题原因 在中断服务例程(ISR)中使用 HAL_Delay() 会导致问题的原因是: 阻塞性: HAL_D…...
62.【C语言】浮点数的存储
目录 1.浮点数的类型 2.浮点数表示的范围 3.浮点数的特性 《计算机科学导论》的叙述 4.浮点数在内存中的存储 答案速查 分析 前置知识:浮点数的存储规则 推导单精度浮点数5.5在内存中的存储 验证 浮点数取出的分析 1.一般情况:E不全为0或不全为1 2.特殊情况:E全为0…...
GO网络编程(一):基础知识
1. 网络编程的基础概念 TCP/IP 协议栈 TCP/IP 是互联网通信的核心协议栈,分为以下四个层次: 应用层(Application Layer):为应用程序提供网络服务的协议,比如 HTTP、FTP、SMTP 等。传输层(Tra…...
网站域名禁止续费/在线生成个人网站
//定义PHP字符集header("content-type:text/html; charsetutf8");//连接数据库¥con mysql content(localhost,root,password);//定义插入data数据库字符集mysql_query(’set names utf8‘);//mysql_query( insert into data sheet name(name) values (ai…...
广东企业移动网站建设哪家好/聊城seo培训
iOS开发UI篇—UITableviewcell的性能问题 一、UITableviewcell的一些介绍 UITableView的每一行都是一个UITableViewCell,通过dataSource的 tableView:cellForRowAtIndexPath:方法来初始化每⼀行 UITableViewCell内部有个默认的子视图:contentView,contentView是UITableViewCell…...
hbuilder手机网站开发/最新热点新闻事件
我用 Linux 有些年头了。在这些年里我很有幸见证了开源的发展。各色各样的发行版在安装方面的努力,也是其中的一个比较独特的部分。以前,安装 Linux 是个最好让有技术的人来干的任务。现在,只要你会装软件,你就会安装 Linux。简单…...
怎么搜 织梦的网站/威海百度seo
三台虚拟机:node1是master节点,node2和node3是数据节点,也可选择将node1也作为数据节点 需要的基本设置1.时间同步(时间服务器),网络,hosts时间同步设置(时间服务器,ntp服务器)yum i…...
沈阳做网站客户多吗/百中搜优化软件靠谱吗
问题描述: Fibonacci数列的递推公式为:FnFn-1Fn-2,其中F1F21。 当n比较大时,Fn也非常大,现在我们想知道,Fn除以10007的余数是多少。 数据规模: 1 < n < 1,000,000 输入输出样例࿱…...
做网站后期都用什么软件/日本樱花免m38vcom费vps
分布式数据库一、分布式数据库分布式数据库由一组数据组成,这些数据物理上分布在计算机网络的不同结点(场地)上,逻辑上是属于同一个系统。每个结点可以执行局部应用,也能通过网络通信子系统执行全局应用。二、分布式数…...