当前位置: 首页 > news >正文

神经网络权重初始化

诸神缄默不语-个人CSDN博文目录

(如果只想看代码,请直接跳到“方法”一节,开头我介绍我的常用方法,后面介绍具体的各种方案)

神经网络通过多层神经元相互连接构成,而这些连接的强度就是通过权重(Weight)来表征的。权重是可训练的参数,意味着它们会在训练过程中根据反向传播算法自动调整,以最小化网络的损失函数。

每个神经元接收到的输入信号会与相应的权重相乘,然后所有这些乘积会被累加在一起,最后可能还会加上一个偏置项(Bias),形成该神经元的净输入。这个净输入随后会被送入激活函数,产生神经元的输出,进而传递给下一层的神经元。在这个过程中,权重决定了信号传递的强度和方向,是调整和控制网络学习过程的关键。

从数学角度看,权重可以被组织成矩阵或张量的形式,以支持高效的矩阵运算和便于处理来自网络上一层的所有输入及其对下一层的影响。训练开始时,权重通常会被初始化为小的随机值,这是为了打破对称性并允许网络学习。随着训练的进行,通过梯度下降算法等优化方法,权重会逐渐调整,以使得网络的预测输出尽可能接近真实标签。

总之,神经网络的权重是连接网络中各层之间的桥梁,它们的值决定了网络的行为和性能,通过训练不断优化这些权重,神经网络能够学习到复杂的数据表示和模式,完成各种复杂的任务。

在深度学习中,神经网络的权重初始化对模型的训练效率和最终性能有着至关重要的影响。适当的初始化方法可以帮助加速收敛,避免陷入局部最小值,同时也可以防止训练过程中的梯度消失或梯度爆炸问题。相反,不当的权重初始化可能导致模型训练效果不佳,甚至无法收敛。

文章目录

  • 权重初始化的必要性
  • 不认真对待的危害
  • 权重初始化方法
    • 1. 随机初始化
    • 2. Xavier/Glorot 初始化
    • 3. He/Kaiming 初始化
    • 4. SVD 初始化
  • 结论
  • 本文撰写过程中使用到的其他参考资料

权重初始化的必要性

  1. 加速收敛:合适的初始化方法能够使神经网络更快地收敛到较低的误差。
  2. 避免梯度问题:通过控制权重的初始范围,可以帮助避免训练过程中的梯度消失或爆炸问题。
  3. 影响泛化能力:初始化不仅影响训练速度和稳定性,也间接影响模型的泛化能力。

不认真对待的危害

  • 训练时间延长:不合适的初始化可能导致模型需要更长的时间来收敛。
  • 性能下降:极端情况下,不合适的初始化会导致模型无法从训练数据中学习有效的特征,从而严重影响模型性能。
  • 训练失败:在某些情况下,错误的初始化方法甚至会导致训练完全失败(例如,梯度消失或爆炸)。

权重初始化方法

我个人的习惯是在构建模型的时候直接对需要手写的权重进行初始化。权重用Xavier初始化,偏置直接初始化为全0向量,代码示例:

from torch.nn.init import xavier_normal_class MPBFNDecoder(nn.Module):def __init__(self):super(MPBFNDecoder,self).__init__()...self.Wf12=nn.Parameter(xavier_normal_(torch.empty(charge_num,ds)))self.Wf13=nn.Parameter(xavier_normal_(torch.empty(penalty_num,ds)))self.Wf23=nn.Parameter(xavier_normal_(torch.empty(penalty_num,ds)))self.b12=nn.Parameter(torch.zeros(charge_num,))self.b13=nn.Parameter(torch.zeros(penalty_num,))self.b23=nn.Parameter(torch.zeros(penalty_num,))

完整代码见https://github.com/PolarisRisingWar/LJP_Collection/blob/master/models/MPBFN/train_and_test.py

PyTorch内置的模型都已经自动写好了初始化函数,不需要手动设置。

以下有些代码示例是指定Linear中的权重进行初始化的。如果你们需要改成对特定参数进行初始化的话也好改,反正你们懂这个意思就行。

1. 随机初始化

Uniform 高斯分布初始化

  • 公式:权重 w ∼ N ( 0 , stdev 2 ) w \sim \mathcal{N}(0, \text{stdev}^2) wN(0,stdev2)

    • 其中, N ( 0 , stdev 2 ) \mathcal{N}(0, \text{stdev}^2) N(0,stdev2) 表示均值为0,标准差为 stdev \text{stdev} stdev 的高斯(正态)分布。
  • 概述:最简单的方法是从某个分布(通常是均匀分布或正态分布)中随机选取权重值。

  • 代码实例

import torch
import torch.nn as nn# 均匀分布初始化
def uniform_init(model):if isinstance(model, nn.Linear):nn.init.uniform_(model.weight, -1, 1)if model.bias is not None:nn.init.constant_(model.bias, 0)# 正态分布初始化
def normal_init(model):if isinstance(model, nn.Linear):nn.init.normal_(model.weight, mean=0, std=1)if model.bias is not None:nn.init.constant_(model.bias, 0)

2. Xavier/Glorot 初始化

公式:权重 w ∼ U ( − 6 n in + n out , 6 n in + n out ) w \sim \mathcal{U}(-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}) wU(nin+nout6 ,nin+nout6 )

  • 其中, U ( a , b ) \mathcal{U}(a, b) U(a,b) 表示均匀分布, n in n_{\text{in}} nin 是层输入的单元数, n out n_{\text{out}} nout 是层输出的单元数。

对梯度消失问题有优势。

  • 论文:(2010 PMLR) Understanding the difficulty of training deep feedforward neural networks
  • 原理:考虑到输入和输出的方差,目的是保持所有层的梯度大小大致相同。
  • 代码实例
def xavier_init(model):if isinstance(model, nn.Linear):nn.init.xavier_uniform_(model.weight)if model.bias is not None:nn.init.constant_(model.bias, 0)

3. He/Kaiming 初始化

  • 公式:权重 w ∼ N ( 0 , 2 n in ) w \sim \mathcal{N}(0, \frac{2}{n_{\text{in}}}) wN(0,nin2)
    • 其中, n in n_{\text{in}} nin 是层输入的单元数,假设权重初始化为均值为0,方差为 2 n in \frac{2}{n_{\text{in}}} nin2 的正态分布。

Kaiming Normal(也称为He Normal)初始化是由何凯明等人在2015年提出的一种权重初始化方法,旨在解决ReLU激活函数在深度神经网络中使用时的梯度消失或爆炸问题。这种方法考虑到了ReLU激活函数特性,特别是其非零区域的分布特点,从而提出通过调整初始化权重的方差来保持信号在前向传播和反向传播过程中的稳定。

  • 论文:(2015 ICCV) Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
  • 原理:Kaiming Normal 初始化的核心思想是根据网络层的输入单元数量(即fan_in)来调整权重的方差,确保各层激活值的方差保持一致,以此来避免在深层网络中出现梯度消失或爆炸的问题。具体来说,该方法建议将权重初始化为均值为0,方差为 2 / fan_in 2/\text{fan\_in} 2/fan_in的正态分布,其中 fan_in \text{fan\_in} fan_in是权重矩阵中输入单元的数量。
  • 代码实例
def he_init(model):if isinstance(model, nn.Linear):nn.init.kaiming_uniform_(model.weight, mode='fan_in', nonlinearity='relu')if model.bias is not None:nn.init.constant_(model.bias, 0)

4. SVD 初始化

  • 公式:无特定公式。SVD 初始化涉及对权重矩阵进行奇异值分解(SVD),然后根据需要重新组合以初始化网络权重。

SVD(奇异值分解)初始化是一种高级权重初始化技术,它通过对权重矩阵应用奇异值分解来初始化神经网络。这种方法特别适用于需要保持输入数据特征或处理特定矩阵结构(如正交性或特定范数)的场合。

对RNN有比较好的效果。参考论文:(2014 ICLR) Exact solutions to the nonlinear dynamics of learning in deep linear neural networks

SVD 初始化的基本思想是将权重矩阵 W W W 分解为三个矩阵的乘积: W = U Σ V T W = U\Sigma V^T W=UΣVT,其中 U U U V V V 是正交矩阵, Σ \Sigma Σ 是对角矩阵,包含 W W W 的奇异值。初始化过程中,可以通过调整 Σ \Sigma Σ 中的奇异值来控制权重矩阵的性质,如其范数或分布特性,从而影响模型的训练动态和最终性能。

代码实例

在PyTorch中实现SVD初始化可能涉及到使用torch.svd对权重矩阵进行奇异值分解,然后根据分解结果来重构权重矩阵。以下是一个简化的示例:

import torch
import torch.nn as nndef svd_init(model):if isinstance(model, nn.Linear):U, S, V = torch.svd(torch.randn(model.weight.shape))# 可以根据需要调整S中的奇异值model.weight.data = torch.mm(U, torch.mm(torch.diag(S), V.t()))if model.bias is not None:nn.init.constant_(model.bias, 0)

SVD初始化提供了一种灵活的方法来控制神经网络权重的特性,尤其是在需要维护输入特征结构或优化训练稳定性的高级应用中。通过精确控制权重矩阵的奇异值,研究者和工程师可以优化网络的初始化状态,从而提高模型训练的效率和效果。然而,由于其实现相对复杂,通常仅在特定需求下采用此方法。

结论

权重初始化在神经网络训练中起着决定性的作用。选择合适的初始化方法可以显著提高训

练效率和模型性能。在实践中,应根据模型的具体结构和使用的激活函数来选择最适合的初始化方法。以上提到的方法仅是众多初始化技术中的几种,研究者和开发者可以根据需要选择或创新更适合自己模型需求的初始化策略。

本文撰写过程中使用到的其他参考资料

  1. 数据竞赛中如何优化深度学习模型

相关文章:

神经网络权重初始化

诸神缄默不语-个人CSDN博文目录 (如果只想看代码,请直接跳到“方法”一节,开头我介绍我的常用方法,后面介绍具体的各种方案) 神经网络通过多层神经元相互连接构成,而这些连接的强度就是通过权重&#xff…...

代码随想录训练营第三十九天|62.不同路径63. 不同路径 II

62.不同路径 1确定dp数组&#xff08;dp table&#xff09;以及下标的含义 从&#xff08;0&#xff0c;0&#xff09;出发到&#xff08;i&#xff0c;j&#xff09;有 dp[i][j]种路径 2确定递推公式 dp[i][j]dp[i-1][j]dp[i][j-1] 3dp数组如何初始化 for(int i0;i<m…...

学习大数据所需的java基础(5)

文章目录 集合框架Collection接口迭代器迭代器基本使用迭代器底层原理并发修改异常 数据结构栈队列数组链表 List接口底层源码分析 LinkList集合LinkedList底层成员解释说明LinkedList中get方法的源码分析LinkedList中add方法的源码分析 增强for增强for的介绍以及基本使用发2.使…...

Python 光速入门课程

首先说一下&#xff0c;为啥小编在即PHP和Golang之后&#xff0c;为啥又要整Python&#xff0c;那是因为小编最近又拿起了 " 阿里天池 " 的东西&#xff0c;所以小编又不得不捡起来大概五年前学习的Python&#xff0c;本篇文章主要讲的是最基础版本&#xff0c;所以比…...

解决vite打包出现 “default“ is not exported by “node_modules/...问题

项目场景&#xff1a; vue3tsvite项目打包 问题描述 // codemirror 编辑器的相关资源 import Codemirror from codemirror;error during build: RollupError: "default" is not exported by "node_modules/vue/dist/vue.runtime.esm-bundler.js", impor…...

c语言strtok的使用

strtok函数的作用为以指定字符分割字符串&#xff0c;含有两个参数&#xff0c;第一个函数为待分割的字符串或者空指针NULL&#xff0c;第二个参数为分割字符集。 对一个字符串首次使用strtok时第一个参数应该是待分割字符串&#xff0c;strtok以指定字符完成第一次分割后&…...

hash,以及数据结构——map容器

1.hash是什么&#xff1f; 定义&#xff1a;hash,一般翻译做散列、杂凑&#xff0c;或音译为哈希&#xff0c;是把任意长度的输入&#xff08;又叫做预映射pre-image&#xff09;通过散列算法变换成固定长度的输出&#xff0c; 该输出就是散列值。这种转换是一种压缩映射&…...

AIoT网关 人工智能物联网网关

AIoT(人工智能物联网)作为新一代技术的代表&#xff0c;正以前所未有的速度改变着我们的生活方式。在这个智能时代&#xff0c;AIoT网关的重要性日益凸显。它不仅是连接智能设备和应用的关键&#xff0c;同时也是实现智能化家居、智慧城市和工业自动化的必备技术。      一…...

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的鸟类识别系统(Python+PySide6界面+训练代码)

摘要&#xff1a;本文详细阐述了一个利用深度学习进行鸟类识别的系统&#xff0c;该系统集成了最新的YOLOv8算法&#xff0c;并与YOLOv7、YOLOv6、YOLOv5等先前版本进行了性能比较。该系统能够在图像、视频、实时视频流和批量文件中精确地识别和分类鸟类。文中不仅深入讲解了YO…...

核密度分析

一.算法介绍 核密度估计&#xff08;Kernel Density Estimation&#xff09;是一种用于估计数据分布的非参数统计方法。它可以用于多种目的和应用&#xff0c;包括&#xff1a; 数据可视化&#xff1a;核密度估计可以用来绘制平滑的密度曲线或热力图&#xff0c;从而直观地表…...

先进语言模型带来的变革与潜力

用户可以通过询问或交互方式与GPT-4这样的先进语言模型互动&#xff0c;开启通往知识宝库的大门&#xff0c;即时访问人类历史积累的知识、经验与智慧。像GPT-4这样的先进语言模型&#xff0c;能够将人类历史上积累的海量知识和经验整合并加以利用。通过深度学习和大规模数据训…...

重铸安卓荣光——上传图片组件

痛点&#xff1a; 公司打算做安卓软件&#xff0c;最近在研究安卓&#xff0c;打算先绘制样式 研究发现安卓并不像前端有那么多组件库&#xff0c;甚至有些基础的组件都需要自己实现&#xff0c;记录一下自己实现的组件 成品展示 一个上传图片的组件 可以选择拍照或者从相册中…...

Bert基础(四)--解码器(上)

1 理解解码器 假设我们想把英语句子I am good&#xff08;原句&#xff09;翻译成法语句子Je vais bien&#xff08;目标句&#xff09;。首先&#xff0c;将原句I am good送入编码器&#xff0c;使编码器学习原句&#xff0c;并计算特征值。在前文中&#xff0c;我们学习了编…...

Visual Studio快捷键记录

日常使用Visual Studio进行开发&#xff0c;记录一下常用的快捷键&#xff1a; 复制&#xff1a;CtrlC剪切&#xff1a;CtrlX粘贴&#xff1a;CtrlV删除&#xff1a;CtrlL撤销&#xff1a;CtrlZ反撤销&#xff1a;CtrlY查找&#xff1a;CtrlF/CtrlI替换&#xff1a;CtrlH框式选…...

分享84个Html个人模板,总有一款适合您

分享84个Html个人模板&#xff0c;总有一款适合您 84个Html个人模板下载链接&#xff1a;https://pan.baidu.com/s/1GXUZlKPzmHvxtO0sm3gHLg?pwd8888 提取码&#xff1a;8888 Python采集代码下载链接&#xff1a;采集代码.zip - 蓝奏云 学习知识费力气&#xff0c;收集…...

vue使用.sync和update实现父组件与子组件数据绑定的案例

在 Vue 中&#xff0c;.sync 是一个用于实现双向数据绑定的特殊修饰符。它允许父组件通过一种简洁的方式向子组件传递一个 prop&#xff0c;并在子组件中修改这个 prop 的值&#xff0c;然后将修改后的值反馈回父组件&#xff0c;实现双向数据绑定。 使用 .sync 修饰符的基本语…...

C语言系列15——C语言的安全性与防御性编程

目录 写在开头1 缓冲区溢出&#xff1a;如何防范与处理1.1 缓冲区溢出的原因1.2 预防与处理策略 2. 安全的字符串处理函数与使用技巧2.1 strncpy函数2.2 snprintf函数2.3 strlcpy函数2.4 使用技巧 3 防御性编程的基本原则与实际方法3.1 基本原则3.2 实际方法 写在最后 写在开头…...

objectMapper、ObjectNode、JsonNode调用接口时进行参数组装

objectMapper、ObjectNode、JsonNode用于调用接口时进行参数组装 public String sendText( List< String > listUser, String content ) throws JsonProcessingException{if ( listUser.size() < 0 ){return "用户ID为空&#xff01;";}if ( content.lengt…...

2024开年,手机厂商革了自己的命

文&#xff5c;刘俊宏 编&#xff5c;王一粟 2024开年&#xff0c;AI终端的号角已经由手机行业吹响。 OPPO春节期间就没闲着&#xff0c;首席产品官刘作虎在大年三十就迫不及待地宣布&#xff0c;OPPO正式进入AI手机时代。随后在开年后就紧急召开了AI战略发布会&#xff0c;…...

【安全】大模型安全综述

大模型相关非安全综述 LLM演化和分类法 A survey on evaluation of large language models,” arXiv preprint arXiv:2307.03109, 2023.“A survey of large language models,” arXiv preprint arXiv:2303.18223, 2023.“A survey on llm-gernerated text detection: Necess…...

Stable Diffusion 模型分享:AstrAnime(Astr动画)

本文收录于《AI绘画从入门到精通》专栏&#xff0c;专栏总目录&#xff1a;点这里。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五 下载地址 模型介绍 AstrAnime 是一个动漫模型&#xff0c;画风色彩鲜明&#xff0c;擅长绘制漂亮的小姐姐。 条目内容类型大模型…...

【GPTs分享】每日GPTs分享之Canva

简介 Canva&#xff0c;旨在帮助用户通过Canva的用户友好设计平台释放用户的创造力。无论用户是想设计海报、社交媒体帖子还是商业名片&#xff0c;Canva都在这里协助用户将创意转化为现实。 主要功能 设计生成&#xff1a;根据用户的描述和创意需求&#xff0c;生成定制的设…...

【机器学习】数据清洗——基于Pandas库的方法删除重复点

&#x1f388;个人主页&#xff1a;豌豆射手^ &#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏 &#x1f917;收录专栏&#xff1a;机器学习 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共同学习、交流进…...

顺序表增删改查(c语言)

main函数&#xff1a; #include <stdio.h>#include "./seq.h"int main(int argc, const char *argv[]){SeqList* list create_seqList();insert_seqList(list,10);insert_seqList(list,100);insert_seqList(list,12);insert_seqList(list,23);show_seqList(l…...

MyBatis Plus中的动态表名实践

随着数据库应用的不断发展&#xff0c;面对复杂多变的业务需求&#xff0c;动态表名的处理变得愈发重要。在 MyBatis Plus&#xff08;以下简称 MP&#xff09;这一优秀的基于 MyBatis 的增强工具的支持下&#xff0c;我们可以更便捷地应对动态表名的挑战。本文将深入研究如何在…...

JAVA IDEA 项目打包为 jar 包详解

前言 如下简单 maven 项目&#xff0c;现在 maven 项目比较流行&#xff0c;你还没用过就OUT了。需要打包jar 先设置&#xff1a;点击 File > Project Structure > Artifacts > 点击加号 > 选择JAR > 选择From modules with dependencies 一、将所有依赖和模…...

概率基础——几何分布

概率基础——几何分布 介绍 在统计学中&#xff0c;几何分布是描述了在一系列独立同分布的伯努利试验中&#xff0c;第一次成功所需的试验次数的概率分布。在连续抛掷硬币的试验中&#xff0c;每次抛掷结果为正面向上的概率为 p p p&#xff0c;反面向上的概率为 1 − p 1-p …...

JavaScript的内存管理与垃圾回收

前言 JavaScript提供了高效的内存管理机制&#xff0c;它的垃圾回收功能是自动的。在我们创建新对象、函数、原始类型和变量时&#xff0c;所有这些编程元素都会占用内存。那么JavaScript是如何管理这些元素并在它们不再使用时清理它们的呢&#xff1f; 在本节中&#xff0c;…...

Neo4j导入数据之JAVA JDBC

目录结构 前言设置neo4j外部访问代码整理maven 依赖java 代码 参考链接 前言 公司需要获取neo4j数据库内容进行数据筛查&#xff0c;neo4j数据库咱也是头一次基础&#xff0c;辛辛苦苦安装好整理了安装neo4j的步骤&#xff0c;如今又遇到数据不知道怎么创建&#xff0c;关关难…...

LeetCode 2878.获取DataFrame的大小

DataFrame players: ------------------- | Column Name | Type | ------------------- | player_id | int | | name | object | | age | int | | position | object | | … | … | ------------------- 编写一个解决方案&#xff0c;计算并显示 players 的 行数和列数。 将结…...