当前位置: 首页 > news >正文

30分钟吃掉wandb可视化自动调参

wandb.sweep: 低代码,可视化,分布式 自动调参工具。

使用wandb 的 sweep 进行超参调优,具有以下优点。

(1)低代码:只需配置一个sweep.yaml配置文件,或者定义一个配置dict,几乎不用编写调参相关代码。

(2)可视化:在wandb网页中可以实时监控调参过程中每次尝试,并可视化地分析调参任务的目标值分布,超参重要性等。

(3)分布式:sweep采用类似master-workers的controller-agents架构,controller在wandb的服务器机器上运行,agents在用户机器上运行,controller和agents之间通过互联网进行通信。同时启动多个agents即可轻松实现分布式超参搜索。

公众号后台回复关键词:wandb,获取本文notebook代码和B站视频演示。

使用 wandb 的sweep 调参的缺点:

需要联网:由于wandb的controller位于wandb的服务器机器上,wandb日志也需要联网上传,在没有互联网的环境下无法正常使用wandb 进行模型跟踪 以及 wandb sweep 可视化调参。

d6731f3afe349a385fa50a5eb394b50a.png

〇,使用Sweep的3步骤

  1. 配置 sweep_config

配置调优算法,调优目标,需要优化的超参数列表 等等。
  1. 初始化 sweep controller:

sweep_id = wandb.sweep(sweep_config,project)
  1. 启动 sweep agents:

wandb.agent(sweep_id, function=train)
import os,PIL 
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch 
from torch import nn 
import torchvision 
from torchvision import transforms
import datetime
import wandb wandb.login()
from argparse import Namespacedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#初始化参数配置
config = Namespace(project_name = 'wandb_demo',batch_size = 512,hidden_layer_width = 64,dropout_p = 0.1,lr = 1e-4,optim_type = 'Adam',epochs = 15,ckpt_path = 'checkpoint.pt'
)

一. 配置 Sweep config

详细配置文档可以参考:https://docs.wandb.ai/guides/sweeps/define-sweep-configuration

1,选择一个调优算法

Sweep支持如下3种调优算法:

(1)网格搜索:grid. 遍历所有可能得超参组合,只在超参空间不大的时候使用,否则会非常慢。

(2)随机搜索:random. 每个超参数都选择一个随机值,非常有效,一般情况下建议使用。

(3)贝叶斯搜索:bayes. 创建一个概率模型估计不同超参数组合的效果,采样有更高概率提升优化目标的超参数组合。对连续型的超参数特别有效,但扩展到非常高维度的超参数时效果不好。

sweep_config = {'method': 'random'}

2,定义调优目标

设置优化指标,以及优化方向。

sweep agents 通过 wandb.log 的形式向 sweep controller 传递优化目标的值。

metric = {'name': 'val_acc','goal': 'maximize'   }
sweep_config['metric'] = metric

3,定义超参空间

超参空间可以分成 固定型,离散型和连续型。

  • 固定型:指定 value

  • 离散型:指定 values,列出全部候选取值。

  • 连续性:需要指定 分布类型 distribution, 和范围 min, max。用于 random 或者 bayes采样。

sweep_config['parameters'] = {}# 固定不变的超参
sweep_config['parameters'].update({'project_name':{'value':'wandb_demo'},'epochs': {'value': 10},'ckpt_path': {'value':'checkpoint.pt'}})# 离散型分布超参
sweep_config['parameters'].update({'optim_type': {'values': ['Adam', 'SGD','AdamW']},'hidden_layer_width': {'values': [16,32,48,64,80,96,112,128]}})# 连续型分布超参
sweep_config['parameters'].update({'lr': {'distribution': 'log_uniform_values','min': 1e-6,'max': 0.1},'batch_size': {'distribution': 'q_uniform','q': 8,'min': 32,'max': 256,},'dropout_p': {'distribution': 'uniform','min': 0,'max': 0.6,}
})

4,定义剪枝策略 (可选)

可以定义剪枝策略,提前终止那些没有希望的任务。

sweep_config['early_terminate'] = {'type':'hyperband','min_iter':3,'eta':2,'s':3
} #在step=3, 6, 12 时考虑是否剪枝
from pprint import pprint
pprint(sweep_config)

二. 初始化 sweep controller

sweep_id = wandb.sweep(sweep_config, project=config.project_name)

三, 启动 Sweep agent

我们需要把模型训练相关的全部代码整理成一个 train函数。

def create_dataloaders(config):transform = transforms.Compose([transforms.ToTensor()])ds_train = torchvision.datasets.MNIST(root="./mnist/",train=True,download=True,transform=transform)ds_val = torchvision.datasets.MNIST(root="./mnist/",train=False,download=True,transform=transform)ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))dl_train =  torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True,num_workers=2,drop_last=True)dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, num_workers=2,drop_last=True)return dl_train,dl_val
def create_net(config):net = nn.Sequential()net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=config.hidden_layer_width,kernel_size = 3))net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)) net.add_module("conv2",nn.Conv2d(in_channels=config.hidden_layer_width,out_channels=config.hidden_layer_width,kernel_size = 5))net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))net.add_module("dropout",nn.Dropout2d(p = config.dropout_p))net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))net.add_module("flatten",nn.Flatten())net.add_module("linear1",nn.Linear(config.hidden_layer_width,config.hidden_layer_width))net.add_module("relu",nn.ReLU())net.add_module("linear2",nn.Linear(config.hidden_layer_width,10))return net
def train_epoch(model,dl_train,optimizer):model.train()for step, batch in enumerate(dl_train):features,labels = batchfeatures,labels = features.to(device),labels.to(device)preds = model(features)loss = nn.CrossEntropyLoss()(preds,labels)loss.backward()optimizer.step()optimizer.zero_grad()return model
def eval_epoch(model,dl_val):model.eval()accurate = 0num_elems = 0for batch in dl_val:features,labels = batchfeatures,labels = features.to(device),labels.to(device)with torch.no_grad():preds = model(features)predictions = preds.argmax(dim=-1)accurate_preds =  (predictions==labels)num_elems += accurate_preds.shape[0]accurate += accurate_preds.long().sum()val_acc = accurate.item() / num_elemsreturn val_acc
def train(config = config):dl_train, dl_val = create_dataloaders(config)model = create_net(config); optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)#======================================================================nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')wandb.init(project=config.project_name, config = config.__dict__, name = nowtime, save_code=True)model.run_id = wandb.run.id#======================================================================model.best_metric = -1.0for epoch in range(1,config.epochs+1):model = train_epoch(model,dl_train,optimizer)val_acc = eval_epoch(model,dl_val)if val_acc>model.best_metric:model.best_metric = val_acctorch.save(model.state_dict(),config.ckpt_path)   nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")#======================================================================wandb.log({'epoch':epoch, 'val_acc': val_acc, 'best_val_acc':model.best_metric})#======================================================================        #======================================================================wandb.finish()#======================================================================return model   #model = train(config)

一切准备妥当,点火🔥🔥。

# 该agent 随机搜索 尝试5次
wandb.agent(sweep_id, train, count=5)

四,调参可视化和跟踪

1,平行坐标系图

可以直观展示哪些超参数组合更加容易获取更好的结果。

7366fd427d9e456030174fd9764948ec.png


2,超参数重要性图

可以显示超参数和优化目标最终取值的重要性,和相关性方向。

79447d4928b8ac70a543e57a9c7141f7.png


caa2b93c396396eb37a888a052bd7cdf.png

相关文章:

30分钟吃掉wandb可视化自动调参

wandb.sweep: 低代码,可视化,分布式 自动调参工具。使用wandb 的 sweep 进行超参调优,具有以下优点。(1)低代码:只需配置一个sweep.yaml配置文件,或者定义一个配置dict,几乎不用编写调参相关代码。(2)可视化…...

【8】AMBA_SOC项目自学IC验证项目-仿真平台脚本使用讲解

仿真平台文件介绍和脚本使用说明 1、项目路径:2、文件夹说明:3、仿真运行命令:第一步:进入项目路径第二步:设置环境第三步:运行仿真第四步:查看波形1、项目路径: 位置:/tool/project/axi 2、文件夹说明: a、env就是放的我们uvm环境相关的env文件; b、out就是我们…...

智慧水务未来技术发展方向预测探讨

随着科技的不断发展和城市化的加速,智慧水务作为一种新的水务模式,逐渐受到广泛关注。未来,智慧水务将会面临更多的技术挑战和商机。本博客将对智慧水务的未来技术发展方向进行预测,以探讨智慧水务未来可能的技术重点。 1. 人工…...

数据结构 | 栈与队列

🔥Go for it!🔥 📝个人主页:按键难防 📫 如果文章知识点有错误的地方,请指正!和大家一起学习,一起进步👀 📖系列专栏:数据结构与算法 &#x1f52…...

Redux 源码分析

Redux 目录结构 redux ├─ .babelrc.js ├─ .editorconfig ├─ .gitignore …...

第五十二章 BFS进阶(二)——双向广搜

第五十二章 BFS进阶(二)——双向广搜一、双向广搜1、优越之处2、实现逻辑3、复杂度分析二、例题1、问题2、分析3、代码一、双向广搜 1、优越之处 双向广搜是指我们从终点和起点同时开始搜索,当二者到达同一个中间状态的时候,即相…...

业务建模题

一. 单选题:1.在活动图中负责在一个活动节点执行完毕后切换到另一个节点的元素是( A)。A.控制流 B.对象流 C.判断节点 D.扩展区城2.以下说法错误的是(C)。A.活动图中的开始标记一般只有一一个,而终止标记可能有多个B.判断节点的出口条件必须保证不互相重复,并且不缺…...

电子秤专用模拟数字(AD)转换器芯片HX711介绍

HX711简介HX711是一款专为高精度电子秤而设计的24 位A/D 转换器芯片。与同类型其它芯片相比,该芯片集成了包括稳压电源、片内时钟振荡器等其它同类型芯片所需要的外围电路,具有集成度高、响应速度快、抗干扰性强等优点。降低了电子秤的整机成本&#xff…...

微服务 RocketMQ-延时消息 消息过滤 管控台搜索问题

~~微服务 RocketMQ-延时消息 消息过滤 管控台搜索问题~~ RocketMQ-延时消息实现延时消息RocketMQ-消息过滤Tag标签过滤SQL标签过滤管控台搜索问题RocketMQ-延时消息 给消息设置延时时间,到一定时间,消费者才能消费的到,中间件内部通过每秒钟扫…...

js发送邮件(node.js)

以前看别人博客留言或者评论文章时必须填写邮箱信息,感觉甚是麻烦。 后来才知道是为了在博主回复后让访客收到邮件,用心良苦。 于是我也在新增留言和文章评论的接口里,新增了给自己发送邮件提醒的功能。 我用的QQ邮箱,具体如下…...

English Learning - Day58 一周高频问题汇总 2023.2.12 周日

English Learning - Day58 一周高频问题汇总 2023.2.12 周日这周主要内容继续说说状语从句结果状语从句这周主要内容 DAY58【周日总结】 一周高频问题汇总 (打卡作业详见 Day59) 一近期主要讲了 一 01.主动脉修饰 以下是最常问到的知识点拓展&#xff…...

【微电网】基于风光储能和需求响应的微电网日前经济调度(Python代码实现)

目录 1 概述 2 知识点及数学模型 3 算例实现 3.1算例介绍 3.2风光参与的模型求解 3.3 风光和储能参与的模型求解 3.5 风光储能和需求响应都参与模型求解 3.6 结果分析对比 4 Python代码及算例数据 1 概述 近年来,微电网、清洁能源等已成为全球关注的热点…...

四种方式的MySQL安装

mysql安装常见的方法有四种序号 安装方式 说明1 yum\rpm简单、快速,不能定制参数2二进制 解压,简单配置就可使用 免安装 mysql-a.b.c-linux2.x-x86_64.tar.gz3源码编译 可以定制参数,安装时间长 mysql-a.b.c.tar.gz4源码制成rpm包 把源码制…...

软考高级信息系统项目管理师系列之九:项目范围管理

软考高级信息系统项目管理师系列之九:项目范围管理 一、范围管理输入、输出、工具和技术表二、范围管理概述三、规划范围管理四、收集需求1.收集需求:2.需求分类3.收集需求的工具与技术4.收集需求过程主要输出5.需求文件内容6.需求管理7.可跟踪性8.双向可跟踪性9.需求跟踪矩阵…...

【项目精选】javaEE健康管理系统(论文+开题报告+答辩PPT+源代码+数据库+讲解视频)

点击下载源码 javaEE健康管理系统主要功能包括:教师登录退出、教师饮食管理、教师健康日志、体检管理等等。本系统结构如下: (1)用户模块: 实现登录功能 实现用户登录的退出 实现用户注册 (2)教…...

ctfshow nodejs

web 334 大小写转换特殊字符绕过。 “ı”.toUpperCase() ‘I’,“ſ”.toUpperCase() ‘S’。 “K”.toLowerCase() ‘k’. payload: CTFſHOW 123456web 335 通过源码可知 eval(xxx),eval 中可以执行 js 代码,那么我们可以依此执行系…...

无线传感器原理及方法|重点理论知识|2021年19级|期末考试

Min-Max定位 【P63】 最小最大法的基本思想是依据未知节点到各锚节点的距离测量值及锚节点的坐标构造若干个边界框,即以参考节点为圆心,未知节点到该锚节点的距离测量值为半径所构成圆的外接矩形,计算外接矩形的质心为未知节点的估计坐标。 多边定位法的浮点运算量大,计算代…...

带你写出符合 Promise/A+ 规范 Promise 的源码

Promise是前端面试中的高频问题,如果你能根据PromiseA的规范,写出符合规范的源码,那么我想,对于面试中的Promise相关的问题,都能够给出比较完美的答案。 我的建议是,对照规范多写几次实现,也许…...

回流与重绘

触发回流与重绘条件👉回流当渲染树中部分或者全部元素的尺寸、结构或者属性发生变化时,浏览器会重新渲染部分或者全部文档的过程就称为 回流。引起回流原因1.页面的首次渲染2.浏览器的窗口大小发生变化3.元素的内容发生变化4.元素的尺寸或者位置发生变化…...

openpyxl表格的简单实用

示例:创建简单的电子表格和条形图 在这个例子中,我们将从头开始创建一个工作表并添加一些数据,然后绘制它。我们还将探索一些有限的单元格样式和格式。 我们将在工作表上输入的数据如下: 首先,让我们加载 openpyxl 并创建一个新工作簿。并获取活动表。我们还将输入我们…...

浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)

✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义(Task Definition&…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合

强化学习(Reinforcement Learning, RL)是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程,然后使用强化学习的Actor-Critic机制(中文译作“知行互动”机制),逐步迭代求解…...

全球首个30米分辨率湿地数据集(2000—2022)

数据简介 今天我们分享的数据是全球30米分辨率湿地数据集,包含8种湿地亚类,该数据以0.5X0.5的瓦片存储,我们整理了所有属于中国的瓦片名称与其对应省份,方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

Spring AI与Spring Modulith核心技术解析

Spring AI核心架构解析 Spring AI(https://spring.io/projects/spring-ai)作为Spring生态中的AI集成框架,其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似,但特别为多语…...

什么是Ansible Jinja2

理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...

Device Mapper 机制

Device Mapper 机制详解 Device Mapper(简称 DM)是 Linux 内核中的一套通用块设备映射框架,为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程,并配以详细的…...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代,加密货币作为一种新兴的金融现象,正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而,加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下,稳定…...

MySQL 知识小结(一)

一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库,分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷,但是文件存放起来数据比较冗余,用二进制能够更好管理咱们M…...

Redis:现代应用开发的高效内存数据存储利器

一、Redis的起源与发展 Redis最初由意大利程序员Salvatore Sanfilippo在2009年开发,其初衷是为了满足他自己的一个项目需求,即需要一个高性能的键值存储系统来解决传统数据库在高并发场景下的性能瓶颈。随着项目的开源,Redis凭借其简单易用、…...

解读《网络安全法》最新修订,把握网络安全新趋势

《网络安全法》自2017年施行以来,在维护网络空间安全方面发挥了重要作用。但随着网络环境的日益复杂,网络攻击、数据泄露等事件频发,现行法律已难以完全适应新的风险挑战。 2025年3月28日,国家网信办会同相关部门起草了《网络安全…...