当前位置: 首页 > news >正文

机器学习洞察 | 一文带你“讲透” JAX

上篇文章中,我们详细分享了 JAX 这一新兴的机器学习模型的发展和优势,本文我们将通过 Amazon SageMaker 示例展示如何部署并使用 JAX。

JAX 的工作机制


JAX 的完整工作机制可以用下面这幅图详细解释:

图片来源:“Intro to JAX” video on YouTube by Jake VanderPlas, Tech leader from JAX team

在图片左侧是开发者自己编写的 Python 代码,JAX 会追踪并变换成 JAX IR 的中间表示,并按照 Python 代码,通过 jax.jit 将其编译成 HLO (High Level Optimized) 代码,代表高级的优化代码,提供给 XLA 进行读取。XLA 在获取编译的 HLO 代码之后,会分配到对应的 CPU、GPU、TPU 或者 ASIC。

对于开发者来说,只需完成您的 Python 代码即可实现这一流程。开发者可以将 JAX 转换视为首先对 Python 函数进行跟踪专门化,然后将其转换为一个小而行为良好的中间形式,然后使用特定于转换的解释规则进行解释。

为什么 JAX 可以在如此小的软件包中提供如此强大的功能呢?

首先,它从熟悉且灵活的编程接口(使用 NumPy 的 Python)开始,并且使用实际的 Python 解释器来完成大部分繁重的工作;其次,它将计算的本质提炼成一个静态具有高阶功能的类型表达式语言,即 Jaxpr 语言。

JAX 应用场景


自 2019 年 JAX 出现之后,使用它的开发者逐年增多。在 2022 年更是达到了非常火热的状态,甚至有人认为它有可能会取代其他的机器学习框架。

支持 JAX 生态的应用场景包括:

  • 深度学习 (Deep Learning):JAX 在深度学习场景下应用很广泛,很多团队基于 JAX 开发了更加高级的 API 支持不同的场景,方便开发者使用。

  • 科学模拟 (Scientific Simulation):JAX 的出现不仅仅是针对于深度学习,其实也拥有很多其他的使命,如科学模拟。

  • 机器人与控制系统 (Robotics and Control Systems)

  • 概率编程 (Probabilistic Programming)

训练和部署深度学习模型


我们用下面这个具体例子展示使用 JAX 来和 Amazon SageMaker 训练和部署深度学习模型,会用到 Amazon SageMaker 的 BYOC 这种模式。

如上图所示,在这个 Amazon SageMaker 的示例中提供了 JAX 的代码示例:https://sagemaker-examples.readthedocs.io/en/latest/advanced_functionality/jax_bring_your_own/train_deploy_jax.html

在 Amazon SageMaker 上基于 JAX 的框架可使用自定义的容器来训练神经网络。

如图的 Amazon SageMaker Examples 提供的 JAX 示例中,我们使用自定义容器在 SageMaker 上 基于 JAX 框架或库训练神经网络。这在单个容器上是可能的,因为我们使用了 sagemaker-training-toolkit,它允许你在自己的自定义容器中使用脚本模式。自定义容器可以使用内置的 SageMaker 训练作业功能,如竞价训练和超参数调整。

训练模型后,您可以将经过训练的模型部署到托管端点。如前所述,SageMaker 具有推理容器,这些容器已针对亚马逊云科技的硬件和常用深度学习框架进行了优化。其中一项优化是针对 TensorFlow 框架的优化。由于 JAX 支持将模型导出为 TensorFlow SavedModel 格式,因此我们使用该功能来展示如何在优化的 SageMaker TensorFlow 推理端点上部署经过训练的模型。

整个训练和部署主要分为以下五个步骤:

  1. 创建 Docker 镜像并将其推送到 Amazon ECR。

  1. 使用 SageMaker 开发工具包传教自定义框架估算器,以便将模型输出归类为 TensorFlowModel。

  1. 代码仓库中有训练估算器的脚本。

  1. 使用 GPU 上的 SageMaker 训练作业来训练每个模型。

  1. 将模型部署到完全托管的终端节点。

下面我们来看看详细步骤:

  1. 创建 Docker 镜像并将其推送到 Amazon ECR。

*创建使用 JAX 训练模型容器的 Dockerfile

Docker 映像是在 NVIDIA 提供的支持 CUDA 的容器之上构建的。为了确保作为 JAX 中功能基础的 jaxlibpackage 支持 CUDA,请从 jax_releases 存储库中下载 jaxlib 软件包。

  • AX releases

https://storage.googleapis.com/jax-releases/jax_releases.html

这里需要注意的是:为了确保作为 JAX 中的功能基础的 JAX library package 能够支持 cuda,建议在去做这个创建自定义容器时,去看一下目前 JAX release 这个存储库中,它下载的这个 JAX library 包的版本号或者相关注意事项等等。

2、使用 SageMaker 开发工具包创建自定义框架估算器,以便将模型输出归类为 TensorFlowModel。

创建基本 SageMaker 框架估算器的子类,将估算器的模型类型指定为 TensorFlow 模型。为此,我们指定了一个自定义 create_model 方法,该方法使用现有的 TensorFlowModel 类来启动推理容器。

3、通过代码仓库训练估算器的脚本。

您可以通过传统的 SageMaker Python SDK 工作流通过模型执行训练、部署和运行推理。我们确保导入并初始化自定义框架估算器的代码片段中定义的 JaxEstimator,然后运行标准的 .fit () 和 .deploy () 调用。

对于 JAX ,可以调用 jax2tf 函数来执行相同的操作。代码在存储库中可用。设置正确的路径 /opt/ml/model/1 非常重要,这是 SageMaker wrapper(封装器) 假定模型已存储的地方。、

前面提到的 JAX 和 TF 的互操作性,目前 JAX 是通过 JAX to TF 这样的一个软件包,来为 JAX 和 TF 的互操作性提供支持,那 jax2tf.convert 是用于在 TensorFlow 的上下文中使用 JAX 函数,那 jax2tf.call_tf 是用于在 JAX 的上下文中使用的 TensorFlow 函数互操作来完成的。

4、使用 GPU 上的 SageMaker 训练作业来训练每个模型。

  1. 将模型部署到完全托管的终端节点。

vanilla_jax_predictor = vanilla_jax_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge"
)
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
def test_image(predictor, test_images, test_labels, image_number):np_img 
= np.expand_dims(np.expand_dims(test_images[image_number], axis=
-1
), axis=
0
)result = predictor.predict(np_img)pred_y = np.argmax(result["predictions"])print("True Label:", test_labels[image_number])print("Predicted Label:", pred_y)plt.imshow(test_images[image_number]

*部署和准备输入的测试图像

*进行推理

有关在 Amazon SageMaker 上使用 JAX 训练和部署深度学习模型的详细过程和代码,请参考亚马逊云科技官方博客

如图所示,上面的两张图是一个部署模型的例子,下面的图是进行推理的例子。由于我们的 Framework Estimator 知道模型将使用 TensorFlowModel 提供服务,因此部署这些端点只是对 estimator.deploy () 方法做调用即可。

参考资料

  • Training and Deploying ML Models using JAX on SageMaker

  • Train and deploy deep learning models using JAX with Amazon SageMaker

  • AX core from scratch

  • Building JAX from source

JAX 是一种越来越流行的库,它支持原生 Python 或 NumPy 函数的可组合函数转换,可用于高性能数值计算和机器学习研究。JAX 提供了编写 NumPy 程序的能力,这些程序可以使用 GPU/TPU 自动差分和加速,从而形成了更灵活的框架来支持现代深度学习架构。在这两篇文章中我们讨论了有关 JAX 的一些主题,希望对您用使用 JAX 这一框架进行深度学习研究有所帮助。

往期推荐


  • 机器学习洞察 | JAX,机器学习领域的“新面孔”

  • 机器学习洞察 | 降本增效,无服务器推理是怎么做到的?

  • 机器学习洞察 | 分布式训练让机器学习更加快速准确

相关文章:

机器学习洞察 | 一文带你“讲透” JAX

在上篇文章中,我们详细分享了 JAX 这一新兴的机器学习模型的发展和优势,本文我们将通过 Amazon SageMaker 示例展示如何部署并使用 JAX。JAX 的工作机制JAX 的完整工作机制可以用下面这幅图详细解释:图片来源:“Intro to JAX” video on YouT…...

OpenFaaS介绍

FaaS 云计算时代出现了大量XaaS形式的概念,从IaaS(Infrastructure as a Service)、PaaS(Platform as a Service)、SaaS(Software as a Service)到容器云引领的CaaS(Containers as a Service),再到火热的微服务架构,它们都在试着将各种软、硬…...

【算法设计与分析】STL容器、递归算法、分治法、蛮力法、回溯法、分支限界法、贪心法、动态规划;各类算法代码汇总

文章目录前言一、STL容器二、递归算法三、分治法四、蛮力法五、回溯法六、分支限界法七、贪心法八、动态规划前言 本篇共为8类算法(STL容器、递归算法、分治法、蛮力法、回溯法、分支限界法、贪心法、动态规划),则各取每类算法中的几例经典示例进行展示。 一、STL容…...

vue初识

第一次接触vue,前端的html,css,jquery,js学习也有段时间了,就照着B站的视频简单看了一些,了解了一些简单的用法,这边做一个记录。 官网 工具:使用VSCode以及Live Server插件(能够实时预览) 第…...

火山引擎入选《2022 爱分析 · DataOps 厂商全景报告》,旗下 DataLeap 产品能力获认可

更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 2 月 9 日,国内领先的数字化市场研究与咨询机构爱分析发布了《2022 爱分析DataOps 厂商全景报告》(以下简称报告),报…...

java-spring_bean的生命周期

生命周期:从创建到消亡的完整过程初始化容器 1. 创建对象(内存分配 ) 2. 执行构造方法 3. 执行属性注入(set操作) 4. 执行bean初始化方法 使用bean 执行业务操作 关闭/销毁容器 1.执行bean销毁方法 bean销毁时机 容…...

微服务相关概念

一、谈谈你对微服务的理解,微服务有哪些优缺点?微服务是由Martin Fowler大师提出的。微服务是一种架构风格,通过将大型的单体应用划分为比较小的服务单元,从而降低整个系统的复杂度。优点:1、服务部署更灵活&#xff1…...

论文解读:(TransA)TransA: An Adaptive Approach for Knowledge Graph Embedding

简介 先前的知识表示方法:TransE、TransH、TransR、TransD、TranSparse等。的损失函数仅单纯的考虑hrh rhr和ttt在某个语义空间的欧氏距离,认为只要欧式距离最小,就认为h和th和th和t的关系为r。显然这种度量指标过于简单,虽然先…...

js将数字转十进制+十六进制(联动el-ui下拉选择框)

十进制与十六进制的整数转化一、十进制转十六进制二、十六进制转十进制三、联动demo一、十进制转十六进制 正则表达式: /^([0-9]||([1-9][0-9]{0,}))$/解析:[0-9]代表个位数,([1-9][0-9]{0,})代表十位及以上 二、十六进制转十进制 正则表达…...

关于RedissonLock的一些所思

关于RedissonClient.getLock() 我们一般的使用Redisson的方式就是: RLock myLock redissonClient.getLock("my_order");//myLock.lock();//myLock.tryLock();就上面的例子里,如果某个线程已经拿到了my_order的锁,那别的线程调用m…...

C++:倒牛奶问题

文章目录题目一、输入二、输出三、思路代码题目 农业,尤其是生产牛奶,是一个竞争激烈的行业。Farmer John发现如果他不在牛奶生产工艺上有所创新,他的乳制品生意可能就会受到重创! 幸运的是,Farmer John想出了一个好主…...

MySQL8.x group_by报错的4种解决方法

在我们使用MySQL的时候总是会遇到各种各样的报错,让人头痛不已。其中有一种报错,sql_modeonly_full_group_by,十分常见,每次都是老长的一串出现,然后带走你所有的好心情,如:LIMIT 0, 1000 Error…...

具有非线性动态行为的多车辆列队行驶问题的基于强化学习的方法

论文地址: Reinforcement Learning Based Approach for Multi-Vehicle Platooning Problem with Nonlinear Dynamic Behavior 摘要 协同智能交通系统领域的最新研究方向之一是车辆编队。研究人员专注于通过传统控制策略以及最先进的深度强化学习 (RL) 方法解决自动…...

TrueNas篇-硬盘直通

硬盘直通 在做硬盘直通之前,在trueNas(或者其他虚拟机)内是检测不到安装的硬盘的。 在pve节点查看硬盘信息 打开pve的shell控制台 输入下面的命令查看硬盘信息: ls -l /dev/disk/by-id/该命令会显示出实际所有的硬盘设备信息,其中ata代…...

手机子品牌的“性能战事”:一场殊途同归的大混战

在智能手机行业进入存量市场后,竞争更加白热化。当各国产手机品牌集体冲高端,旗下子品牌们也正厮杀正酣,显现出刀光剑影。处理器、屏幕、内存、价格等各方面无不互相对标,激烈程度并不亚于高端之争。源于OPPO的中端手机品牌realme…...

dockerfile自定义镜像安装jdk8,nginx,后端jar包和前端静态文件,并启动容器访问

dockerfile自定义镜像安装jdk8,nginx,后端jar包和前端静态文件,并启动容器访问简介centos7系统里面我准备的服务如下:5gsignplay-web静态文件内容如下:nginx.conf配置文件内容如下:Dockerfile内容如下:run.sh启动脚本内容如下:制作镜像并启动访问简介 通过用docker…...

MongoDB 全文检索

MongoDB 全文检索 全文检索对每一个词建立一个索引,指明该词在文章中出现的次数和位置,当用户查询时,检索程序就根据事先建立的索引进行查找,并将查找的结果反馈给用户的检索方式。 这个过程类似于通过字典中的检索字表查字的过…...

JS中声明变量,使用 var、let、const的区别

一、var 的使用 1.1、var 的作用域 1、var可以在全局范围声明或函数/局部范围内声明。当在最外层函数的外部声明var变量时,作用域是全局的。这意味着在最外层函数的外部用var声明的任何变量都可以在windows中使用。 2、当在函数中声明var时,作用域是局…...

汽车改装避坑指南:大尾翼

今天给大家讲一个改装的误区:大尾翼 很多车友看到一些汽车加了大尾翼,非常的好看,就想给自己的车也加装一个。 那你有没有想过,尾翼这东西你真的需要吗? 赛车为什么加尾翼?尾翼主要是给车尾部的一个压低提供…...

【Unity资源下载】POLYGON Dungeon Realms - Low Poly 3D Art by Synty

$149.99 Synty Studios 一个史诗般的低多边形资产包,包括人物、道具、武器和环境资产,用于创建一个以奇幻为主题的多边形风格游戏。 模块化的部分很容易在各种组合中拼凑起来。 包包含超过1,118个详细预制件。 主要特点 ◼ ◼ 完全模块化的地下城!包…...

visual studio 2022更改主题为深色

visual studio 2022更改主题为深色 点击visual studio 上方的 工具-> 选项 在选项窗口中,选择 环境 -> 常规 ,将其中的颜色主题改成深色 点击确定,更改完成...

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话: “利润不是赚出来的,是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业,很多企业看着销售不错,账上却没钱、利润也不见了,一翻库存才发现: 一堆卖不动的旧货…...

【决胜公务员考试】求职OMG——见面课测验1

2025最新版!!!6.8截至答题,大家注意呀! 博主码字不易点个关注吧,祝期末顺利~~ 1.单选题(2分) 下列说法错误的是:( B ) A.选调生属于公务员系统 B.公务员属于事业编 C.选调生有基层锻炼的要求 D…...

【开发技术】.Net使用FFmpeg视频特定帧上绘制内容

目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法,当前调用一个医疗行业的AI识别算法后返回…...

.Net Framework 4/C# 关键字(非常用,持续更新...)

一、is 关键字 is 关键字用于检查对象是否于给定类型兼容,如果兼容将返回 true,如果不兼容则返回 false,在进行类型转换前,可以先使用 is 关键字判断对象是否与指定类型兼容,如果兼容才进行转换,这样的转换是安全的。 例如有:首先创建一个字符串对象,然后将字符串对象隐…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

USB Over IP专用硬件的5个特点

USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中,从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备(如专用硬件设备),从而消除了直接物理连接的需要。USB over IP的…...

Python 实现 Web 静态服务器(HTTP 协议)

目录 一、在本地启动 HTTP 服务器1. Windows 下安装 node.js1)下载安装包2)配置环境变量3)安装镜像4)node.js 的常用命令 2. 安装 http-server 服务3. 使用 http-server 开启服务1)使用 http-server2)详解 …...

Ubuntu Cursor升级成v1.0

0. 当前版本低 使用当前 Cursor v0.50时 GitHub Copilot Chat 打不开,快捷键也不好用,当看到 Cursor 升级后,还是蛮高兴的 1. 下载 Cursor 下载地址:https://www.cursor.com/cn/downloads 点击下载 Linux (x64) ,…...

AI语音助手的Python实现

引言 语音助手(如小爱同学、Siri)通过语音识别、自然语言处理(NLP)和语音合成技术,为用户提供直观、高效的交互体验。随着人工智能的普及,Python开发者可以利用开源库和AI模型,快速构建自定义语音助手。本文由浅入深,详细介绍如何使用Python开发AI语音助手,涵盖基础功…...