網(wǎng)站建設(shè)價格明細(xì)表和網(wǎng)站預(yù)算網(wǎng)站推廣常用方法
在上篇文章中,我們詳細(xì)分享了 JAX 這一新興的機(jī)器學(xué)習(xí)模型的發(fā)展和優(yōu)勢,本文我們將通過 Amazon SageMaker 示例展示如何部署并使用 JAX。
JAX 的工作機(jī)制
JAX 的完整工作機(jī)制可以用下面這幅圖詳細(xì)解釋:

圖片來源:“Intro to JAX” video on YouTube by Jake VanderPlas, Tech leader from JAX team
在圖片左側(cè)是開發(fā)者自己編寫的 Python 代碼,JAX 會追蹤并變換成 JAX IR 的中間表示,并按照 Python 代碼,通過 jax.jit 將其編譯成 HLO (High Level Optimized) 代碼,代表高級的優(yōu)化代碼,提供給 XLA 進(jìn)行讀取。XLA 在獲取編譯的 HLO 代碼之后,會分配到對應(yīng)的 CPU、GPU、TPU 或者 ASIC。
對于開發(fā)者來說,只需完成您的 Python 代碼即可實現(xiàn)這一流程。開發(fā)者可以將 JAX 轉(zhuǎn)換視為首先對 Python 函數(shù)進(jìn)行跟蹤專門化,然后將其轉(zhuǎn)換為一個小而行為良好的中間形式,然后使用特定于轉(zhuǎn)換的解釋規(guī)則進(jìn)行解釋。
為什么 JAX 可以在如此小的軟件包中提供如此強(qiáng)大的功能呢?
首先,它從熟悉且靈活的編程接口(使用 NumPy 的 Python)開始,并且使用實際的 Python 解釋器來完成大部分繁重的工作;其次,它將計算的本質(zhì)提煉成一個靜態(tài)具有高階功能的類型表達(dá)式語言,即 Jaxpr 語言。
JAX 應(yīng)用場景
自 2019 年 JAX 出現(xiàn)之后,使用它的開發(fā)者逐年增多。在 2022 年更是達(dá)到了非?;馃岬臓顟B(tài),甚至有人認(rèn)為它有可能會取代其他的機(jī)器學(xué)習(xí)框架。
支持 JAX 生態(tài)的應(yīng)用場景包括:

深度學(xué)習(xí) (Deep Learning):JAX 在深度學(xué)習(xí)場景下應(yīng)用很廣泛,很多團(tuán)隊基于 JAX 開發(fā)了更加高級的 API 支持不同的場景,方便開發(fā)者使用。
科學(xué)模擬 (Scientific Simulation):JAX 的出現(xiàn)不僅僅是針對于深度學(xué)習(xí),其實也擁有很多其他的使命,如科學(xué)模擬。
機(jī)器人與控制系統(tǒng) (Robotics and Control Systems)
概率編程 (Probabilistic Programming)
訓(xùn)練和部署深度學(xué)習(xí)模型
我們用下面這個具體例子展示使用 JAX 來和 Amazon SageMaker 訓(xùn)練和部署深度學(xué)習(xí)模型,會用到 Amazon SageMaker 的 BYOC 這種模式。

如上圖所示,在這個 Amazon SageMaker 的示例中提供了 JAX 的代碼示例:https://sagemaker-examples.readthedocs.io/en/latest/advanced_functionality/jax_bring_your_own/train_deploy_jax.html
在 Amazon SageMaker 上基于 JAX 的框架可使用自定義的容器來訓(xùn)練神經(jīng)網(wǎng)絡(luò)。
如圖的 Amazon SageMaker Examples 提供的 JAX 示例中,我們使用自定義容器在 SageMaker 上 基于 JAX 框架或庫訓(xùn)練神經(jīng)網(wǎng)絡(luò)。這在單個容器上是可能的,因為我們使用了 sagemaker-training-toolkit,它允許你在自己的自定義容器中使用腳本模式。自定義容器可以使用內(nèi)置的 SageMaker 訓(xùn)練作業(yè)功能,如競價訓(xùn)練和超參數(shù)調(diào)整。
訓(xùn)練模型后,您可以將經(jīng)過訓(xùn)練的模型部署到托管端點。如前所述,SageMaker 具有推理容器,這些容器已針對亞馬遜云科技的硬件和常用深度學(xué)習(xí)框架進(jìn)行了優(yōu)化。其中一項優(yōu)化是針對 TensorFlow 框架的優(yōu)化。由于 JAX 支持將模型導(dǎo)出為 TensorFlow SavedModel 格式,因此我們使用該功能來展示如何在優(yōu)化的 SageMaker TensorFlow 推理端點上部署經(jīng)過訓(xùn)練的模型。
整個訓(xùn)練和部署主要分為以下五個步驟:
創(chuàng)建 Docker 鏡像并將其推送到 Amazon ECR。
使用 SageMaker 開發(fā)工具包傳教自定義框架估算器,以便將模型輸出歸類為 TensorFlowModel。
代碼倉庫中有訓(xùn)練估算器的腳本。
使用 GPU 上的 SageMaker 訓(xùn)練作業(yè)來訓(xùn)練每個模型。
將模型部署到完全托管的終端節(jié)點。
下面我們來看看詳細(xì)步驟:
創(chuàng)建 Docker 鏡像并將其推送到 Amazon ECR。

*創(chuàng)建使用 JAX 訓(xùn)練模型容器的 Dockerfile
Docker 映像是在 NVIDIA 提供的支持 CUDA 的容器之上構(gòu)建的。為了確保作為 JAX 中功能基礎(chǔ)的 jaxlibpackage 支持 CUDA,請從 jax_releases 存儲庫中下載 jaxlib 軟件包。
AX releases
https://storage.googleapis.com/jax-releases/jax_releases.html
這里需要注意的是:為了確保作為 JAX 中的功能基礎(chǔ)的 JAX library package 能夠支持 cuda,建議在去做這個創(chuàng)建自定義容器時,去看一下目前 JAX release 這個存儲庫中,它下載的這個 JAX library 包的版本號或者相關(guān)注意事項等等。
2、使用 SageMaker 開發(fā)工具包創(chuàng)建自定義框架估算器,以便將模型輸出歸類為 TensorFlowModel。


創(chuàng)建基本 SageMaker 框架估算器的子類,將估算器的模型類型指定為 TensorFlow 模型。為此,我們指定了一個自定義 create_model 方法,該方法使用現(xiàn)有的 TensorFlowModel 類來啟動推理容器。
3、通過代碼倉庫訓(xùn)練估算器的腳本。
您可以通過傳統(tǒng)的 SageMaker Python SDK 工作流通過模型執(zhí)行訓(xùn)練、部署和運(yùn)行推理。我們確保導(dǎo)入并初始化自定義框架估算器的代碼片段中定義的 JaxEstimator,然后運(yùn)行標(biāo)準(zhǔn)的 .fit () 和 .deploy () 調(diào)用。

對于 JAX ,可以調(diào)用 jax2tf 函數(shù)來執(zhí)行相同的操作。代碼在存儲庫中可用。設(shè)置正確的路徑 /opt/ml/model/1 非常重要,這是 SageMaker wrapper(封裝器) 假定模型已存儲的地方。、

前面提到的 JAX 和 TF 的互操作性,目前 JAX 是通過 JAX to TF 這樣的一個軟件包,來為 JAX 和 TF 的互操作性提供支持,那 jax2tf.convert 是用于在 TensorFlow 的上下文中使用 JAX 函數(shù),那 jax2tf.call_tf 是用于在 JAX 的上下文中使用的 TensorFlow 函數(shù)互操作來完成的。
4、使用 GPU 上的 SageMaker 訓(xùn)練作業(yè)來訓(xùn)練每個模型。
將模型部署到完全托管的終端節(jié)點。
vanilla_jax_predictor = vanilla_jax_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge"
)
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
def test_image(predictor, test_images, test_labels, image_number):np_img
= np.expand_dims(np.expand_dims(test_images[image_number], axis=
-1
), axis=
0
)result = predictor.predict(np_img)pred_y = np.argmax(result["predictions"])print("True Label:", test_labels[image_number])print("Predicted Label:", pred_y)plt.imshow(test_images[image_number]
*部署和準(zhǔn)備輸入的測試圖像

*進(jìn)行推理
有關(guān)在 Amazon SageMaker 上使用 JAX 訓(xùn)練和部署深度學(xué)習(xí)模型的詳細(xì)過程和代碼,請參考亞馬遜云科技官方博客。
如圖所示,上面的兩張圖是一個部署模型的例子,下面的圖是進(jìn)行推理的例子。由于我們的 Framework Estimator 知道模型將使用 TensorFlowModel 提供服務(wù),因此部署這些端點只是對 estimator.deploy () 方法做調(diào)用即可。
參考資料
Training and Deploying ML Models using JAX on SageMaker
Train and deploy deep learning models using JAX with Amazon SageMaker
AX core from scratch
Building JAX from source
JAX 是一種越來越流行的庫,它支持原生 Python 或 NumPy 函數(shù)的可組合函數(shù)轉(zhuǎn)換,可用于高性能數(shù)值計算和機(jī)器學(xué)習(xí)研究。JAX 提供了編寫 NumPy 程序的能力,這些程序可以使用 GPU/TPU 自動差分和加速,從而形成了更靈活的框架來支持現(xiàn)代深度學(xué)習(xí)架構(gòu)。在這兩篇文章中我們討論了有關(guān) JAX 的一些主題,希望對您用使用 JAX 這一框架進(jìn)行深度學(xué)習(xí)研究有所幫助。
往期推薦
機(jī)器學(xué)習(xí)洞察 | JAX,機(jī)器學(xué)習(xí)領(lǐng)域的“新面孔”
機(jī)器學(xué)習(xí)洞察 | 降本增效,無服務(wù)器推理是怎么做到的?
機(jī)器學(xué)習(xí)洞察 | 分布式訓(xùn)練讓機(jī)器學(xué)習(xí)更加快速準(zhǔn)確
