当前位置: 首页 > news >正文

Pytorch训练固定随机种子(单卡场景和分布式训练场景)

模型的训练是一个随机过程,固定随机种子可以帮助我们复现实验结果。

接下来介绍一个模型训练过程中固定随机种子的代码,并对每条语句的作用都会进行解释。

def seed_reproducer(seed=2333):random.seed(seed)os.environ["PYTHONHASHSEED"] = str(seed)np.random.seed(seed)torch.manual_seed(seed)if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.enabled = True

这是一个自定义函数,函数的参数就是我们传入的种子的数值,类型为int,作用就是消除训练过程中的随机性,以确保实验得可重复性,具体使用方法为在初始化模型和dataset前调用该函数即可。

接下来逐句讲解每个语句的作用。


random.seed(seed)

  • 作用:设置python内置random模块的种子,确保所有使用random模块产生的随机数序列是确定的;
  • 应用场景:适用于任何使用random模块的地方,例如数据增强、采样等。

os.environ["PYTHONHASHSEED"] = str(seed)

  • 作用:设置python的哈希种子,python字典和其他哈希表结构依赖于哈希函数,而哈希函数的行为在不同运行之间可能会不同,通过设置PYTHONHASHSEED环境变量,可使哈希结果在同一种子下保持一致;
  • 应用场景:确保字典键值对顺序的一致性,避免因哈希碰撞引起的非确定性行为。

np.random.seed(seed)

  • 作用:设置Numpy库的随机种子;
  • 应用场景:适用于所有使用Numpy生成随机数的地方,例如初始化权重、数据打乱等。

torch.manual_seed(seed)

  • 作用:为Pytorch设置全局种子,确保所有使用Pytorch身材的随机数(包括张量操作)都是确定性的;
  • 应用场景:适用于所有使用Pytorch生成随机数的地方。

 为了方便结合代码进行理解,在这里单独把后半部分的代码复制一下。

if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.enabled = True

if torch.cuda.is_available():

条件判断,只有当CUDA可用时才进行以下设置,这确保了代码可以在没有GPU的环境中正常运行。

torch.cuda.manual_seed(seed)

  • 作用:为当前GPU设置随机种子,这确保了所有在当前GPU上生成的随机数都是确定的;
  • 应用场景:适用于单个GPU的随机数生成。

torch.cuda.manual_seed_all(seed)

  • 作用:为所有可用的GPU设置相同的随机种子,这确保了在多GPU环境中,每个GPU上随机生成的随机数都是一致的;
  • 应用场景:适用于多GPU环境中的随机数生成。

torch.backends.cudnn.deterministic = True

  • 作用:确保cuDNN使用确定性的算法,某些cuDNN算法是具有随机性的,启用此选项可以提高结果的可重复性,但是可能会降低性能;
  • 应用场景:适用于需要严格可重复性的实验。

torch.backends.cudnn.benchmark = False

  • 作用:禁用cuDNN的自动选择最佳卷积算法的功能,默认情况下cuDNN会在首次运行时尝试找到最适合硬件的算法,这可能会导致结果的不确定性,禁用此选项可以确保每次都是用相同的算法;
  • 应用场景:适用于需要严格可重复性的实验。

torch.backends.cudnn.enabled = True

  • 作用:启用cuDNN,虽然设置了deterministicbenchmark参数来控制cuDNN的行为,但仍然需要确保cuDNN是启用的;
  • 应用场景:所有需要使用GPU加速pytorch计算的场景。

多卡训练的情况

对于多卡训练的情况,设置随机种子的方式需要特别注意,以确保每个进程(或者称为“rank”)生成的随机数序列是不同的,同时还需要保证整个训练过程的可重复性。基于上述要求,随机种子可使用如下代码进行修改:

seed = args.seed + utils.get_rank()
  • args.seed:我们指定的seed的数值,通常在模型训练的配置文件中,或者通过命令行参数传入模型的训练脚本;
  • utils.get_rank():这是自定义的一个函数,位于自定义库utils中,具体作用是获得当前进程的全局序号,比如在进行分布式训练时,有2台机器(我们成之为2个节点),每台机器有8张GPU,则一共会有2*8=16个进程,每个进程都会有唯一的序号,从0~15。

因此我们看到,对于分布式训练场景,随机种子只需要确定一个全局种子(args.seed)在加上一个增量(utils.get_rank()),而这个增量对于每个进程来说是固定的。

下面是get_rank()函数的具体实现:

import torch.distributed as distdef is_dist_avail_and_initialized():# 判断当前环境中是否支持分布式训练if not dist.is_available():return False# 检查当前环境是否已经成功初始化了分布式训练环境if not dist.is_initialized():return Falsereturn Truedef get_rank():# 判断分布式训练是否可用且是否已成功初始化if not is_dist_avail_and_initialized():return 0return dist.get_rank()

相关文章:

Pytorch训练固定随机种子(单卡场景和分布式训练场景)

模型的训练是一个随机过程,固定随机种子可以帮助我们复现实验结果。 接下来介绍一个模型训练过程中固定随机种子的代码,并对每条语句的作用都会进行解释。 def seed_reproducer(seed2333):random.seed(seed)os.environ["PYTHONHASHSEED"] s…...

Conda + JuiceFS :增强 AI 开发环境共享能力

Conda 是当前 AI 应用开发领域中非常流行的环境和包管理系统,因其能够简单便捷地创建与系统资源相隔离的虚拟环境广受欢迎。 Conda 支持在不同的操作系统上重建相同的工作环境,但在环境共享复用方面仍存在一些挑战。比如,在不同机器上复用相…...

人工智能-人机交互的机会

目录 引言HCI领域的发展机会人工智能领域的崛起与机会博雅智信的HCI与AI辅导服务结语 引言 在人类科技不断进步的今天,HCI(人机交互)和人工智能(AI)是两个密切相关且充满潜力的领域。HCI研究如何优化人类与计算机之间…...

【系统架构核心服务设计】使用 Redis ZSET 实现排行榜服务

目录 一、排行榜的应用场景 二、排行榜技术的特点 三、使用Redis ZSET实现排行榜 3.1 引入依赖 3.2 配置Redis连接 3.3 创建实体类(可选) 3.4 编写 Redis 操作服务层 3.5 编写控制器层 3.6 测试 3.6.1 测试 addMovieScore 接口 3.6.2 测试 g…...

elasticsearch基础总结

最近实习,项目用的elasticseatch做的存储库,但是之前对于es接触的不多,查询语法有些不熟,每次想写个DSL查询时都要gpt或者施展搜索大法,所以索性就自己总结总结,以后忘了也方便查。所以这篇文章会持续更新。…...

【慕伏白教程】Zerotier 连接与简单配置

文章目录 下载与安装WindowsLinuxapt安装官方脚本安装 Zerotier 配置新建网络网络配置 终端配置WindowsLinux 下载与安装 Windows 进入Zerotier官方下载网站,点击下载 在下载目录找到安装文件,双击打开后点击 Install 开始安装 安装完成后,…...

Brain.js(九):LSTMTimeStep 实战教程 - 未来短期内的股市指数预测 - 实操要谨慎

系列的前一文RNNTimeStep 实战教程 - 股票价格预测 讲述了如何使用RNN时间序列预测实时的股价, 在这一节中,我们将深入学习如何利用 JavaScript 在浏览器环境下使用 LSTMTimeStep 进行股市指数的短期预测。通过本次实战教程,你将了解到如何用…...

C# 字符串(String)

文章目录 前言创建 String 对象的方式1. 通过给 String 变量指定一个字符串2. 通过使用 String 类构造函数3. 通过使用字符串串联运算符( )4. 通过检索属性或调用一个返回字符串的方法5. 通过格式化方法来转换一个值或对象为它的字符串表示形式 String …...

二进制文件

大多数人听到“二进制”的时候,脑海里可能马上就会联想到电影《黑客帝国》中由“0”和“1”组成的矩阵。 笔者不打算在这里详细讨论二进制的运算、反码、补码之类枯燥的东西,但有几个和开发相关的概念需要做一点澄清和普及。因为这些内容就像空气——用…...

【电子元器件】音频功放种类

本文章是笔者整理的备忘笔记。希望在帮助自己温习避免遗忘的同时,也能帮助其他需要参考的朋友。如有谬误,欢迎大家进行指正。 一、概述 音频功放将小信号的幅值提高至有用电平,同时保留小信号的细节,这称为线性度。放大器的线性…...

linux之vim

一、模式转换命令 vim主要有三种模式:命令模式(Normal Mode)、输入模式(Insert Mode)和底线命令模式(Command-Line Mode)。 从命令模式切换到输入模式:i:在当前光标所在…...

QT的ui界面显示不全问题(适应高分辨率屏幕)

//自动适应高分辨率 QCoreApplication::setAttribute(Qt::AA_EnableHighDpiScaling);一、问题 电脑分辨率高,默认情况下,打开QT的ui界面,显示不全按钮内容 二、解决方案 如果自己的电脑分辨率较高,可以尝试以下方案:自…...

数据结构--串、数组和广义表

串 定义:串(String)是由零个或多个字符组成的有限序列。 子串:串中任意个连续字符组成的子序列称为该串的子串。 主串:包含子串的串相应地称为主串。 字符位置:字符在该序列中的序号为该字符在串中的位置…...

LLMs之Agent之Lares:Lares的简介、安装和使用方法、案例应用之详细攻略

LLMs之Agent之Lares:Lares的简介、安装和使用方法、案例应用之详细攻略 导读:这篇博文介绍了 Lares,一个由简单的 AI 代理驱动的智能家居助手模拟器,它展现出令人惊讶的解决问题能力。 >> 背景痛点:每天都有新的…...

1-1.mysql2 之 mysql2 初识(mysql2 初识案例、初识案例挖掘)

一、mysql2 概述 mysql2 是一个用于 Node.js 的 MySQL 客户端库 mysql2 是 mysql 库的一个改进版本,提供了更好的性能和更多的功能 使用 mysql2 之前,需要先安装它 npm install mysql2 二、mysql2 初识案例 1、数据库准备 创建数据库 testdb CREAT…...

企业邮箱为什么不能经常群发邮件?

企业邮箱是用企业域名作为后缀的邮箱,虽然企业邮箱确实具备群发邮件的功能,但它更适用于企业内部的群发,而非用于外部推广。如果是在企业邮件域内进行群发,通常可以借助企业邮箱的邮件列表来实现。然而,对于域外的大量…...

集成运算放大电路反馈判断

集成运算放大电路 一种具有很高放大倍数的多级直接耦合放大电路,因最初用于信号运算而得名,简称集成运放或运放 模拟集成电路中的典型组件,是发展最快、品种最多、应用最广的一种 反馈 将放大电路输出信号的一部分或全部通过某种电路引回到输…...

媒体查询、浏览器一帧渲染过程

文章目录 媒体查询语法示例根据视口宽度应用不同的样式根据设备像素比应用不同的样式根据方向应用不同的样式 使用场景 浏览器一帧的渲染过程 媒体查询 媒体查询(Media Query)是CSS3中的一个重要特性,它允许开发者根据设备的特定条件&#x…...

高级排序算法(一):快速排序详解

引言 当我们处理大规模数据时,像冒泡排序、选择排序这样的基础排序算法就有点力不从心了。这时候,快速排序(Quick Sort)就派上用场了。 作为一种基于分治法的高效排序算法,快速排序在大多数情况下可以在O(n log n)的时…...

3.2 网络协议IP

欢迎大家订阅【计算机网络】学习专栏,开启你的计算机网络学习之旅! 文章目录 1 定义2 虚拟互连网络3 分组在互联网中的传送4 IPv4 地址 1 定义 网际协议 IP是 TCP/IP 体系中两个最主要的协议之一,也是最重要的互连网协议之一。IPv4 和 IPv6 …...

2024 一带一路暨金砖国家技能发展与技术创新大赛【网络安全防护治理实战技能赛项】样题(中职组)

2024 一带一路暨金砖国家技能发展与技术创新大赛【网络安全防护治理实战技能赛项】样题(中职组) 1.基础设置和安全强化(xxx 分)1.3. 任务内容: 2.安全监测和预警(xxx 分)2.1. 任务一:建立目录安…...

excel如何让单元格选中时显示提示信息?

现象: 当鼠标放在单元格上,会出现提示信息: 先选中单元格选择上方的【数据】-【数据验证】图标选择【输入信息】勾上【选定单元格时显示输入信息】输入【标题】,如:最上方图中的:姓名:输入【输…...

oscp备考,oscp系列——Kioptix Level 3靶场

Kioptix Level 3 oscp备考,oscp系列——Kioptix Level 3靶场 nmap扫描 主机发现 └─# nmap -sn 192.168.80.0/24 Starting Nmap 7.94SVN ( https://nmap.org ) at 2024-12-09 00:33 CST Nmap scan report for 192.168.80.1 Host is up (0.00014s latency). MAC…...

信创改造-达梦数据库配置项 dm.ini 优化

设置模式:兼容MySQL,COMPATIBLE_MODE 4 内存占比:90%,MAX_OS_MEMORY 90 目标内存:2G(不影响申请内存超过2G,但这部分内存不会回收),MEMORY_TARGET 2000 参考 https:…...

日本IT-需要掌握哪些技术框架?一篇通读

在日本从事IT工作,需要掌握的技术框架与全球范围内的趋势相似,但也有一些特定的技术和框架在日本更为流行。以下是一些在日本IT行业中常用的技术框架: Java后端 Java语言:Java在日本是一门非常稳定且受欢迎的编程语言&#xff0…...

错题:Linux C语言

题目&#xff1a;手写代码&#xff1a;判断一个数&#xff08;int类型的整数&#xff09;中有有多少1 题目&#xff1a;手写代码&#xff1a;判断一个数(转换成二进制表示时)有几个1 #include <stdio.h> int main(int argc, const char *argv[]) { //判断一个数&#xf…...

多表设计-一对多一对多-外键

一.多表设计概述&#xff1a; 二.一对多&#xff1a; 1.需求&#xff1a; 根据 页面原型 及 需求文档&#xff0c;完成部门及员工模块的表结构设计 -->部门和员工就是一对多&#xff0c;因为一个部门下会有多个员工&#xff0c;但一个员工只归属一个部门 2.页面原型&…...

Ch1:古今的manipulation与仿真、ROS和Drake介绍

不同的机器人研究与仿真 以前&#xff08;15年左右&#xff09;只能用仿真环境训练行走机器人&#xff0c;对于manipulation任务&#xff0c;有两个问题&#xff1a;1&#xff09;相机不真实&#xff1b;2&#xff09;接触行为太复杂。 I remember just a few years ago (~201…...

JAVA秋招面试题精选-第一天总结

目录 分栏简介&#xff1a; 问题一&#xff1a;订单表每天新增500W条数据&#xff0c;分库分表应该怎么设计&#xff1f; 问题难度以及频率&#xff1a; 问题导向&#xff1a; 满分答案&#xff1a; 举一反三&#xff1a; 问题总结&#xff1a; 问题二&#xff1a;解释…...

服务器卸载安装的 Node.js

卸载安装的 Node.js 版本&#xff0c;具体步骤取决于你是通过包管理器&#xff08;如 yum 或 dnf&#xff09;安装的&#xff0c;还是通过 nvm (Node Version Manager) 安装的。以下是针对这两种情况的指南。 通过包管理器卸载 Node.js 如果你是通过 yum 或 dnf 安装的 Node.…...