Model Deployment & Inference
- Model deployment
- Model inference
We will convert a model trained with PyTorch to the ONNX format, then run it with ONNX Runtime to perform inference.
1. ONNX
ONNX (Open Neural Network Exchange) is a standard format for describing computation graphs, released jointly by Facebook (now Meta) and Microsoft in 2017. By defining a set of environment- and platform-independent standards, ONNX lets AI models be exchanged across different frameworks and runtimes. It can be seen as a bridge between deep-learning frameworks and deployment targets, much like an intermediate language in a compiler.
Because compatibility varies across frameworks, ONNX is usually used to represent static graphs, which are easier to deploy. Hardware and software vendors only need to optimize model performance against the ONNX standard, and every ONNX-compatible framework benefits.
ONNX focuses mainly on model inference: a model trained in any framework, once converted to ONNX format, can easily be deployed in any ONNX-compatible runtime.
- ONNX website: https://onnx.ai/
- ONNX GitHub: https://github.com/onnx/onnx
2. ONNX Runtime
- ONNX Runtime website: https://www.onnxruntime.ai/
- ONNX Runtime GitHub: https://github.com/microsoft/onnxruntime
ONNX Runtime is a cross-platform machine-learning inference accelerator maintained by Microsoft. It consumes ONNX directly: it can load a .onnx file and run inference on it, with no need to convert the .onnx file into any other format.
With ONNX Runtime, PyTorch covers the last mile of deployment, forming the PyTorch --> ONNX --> ONNX Runtime deployment pipeline.
Install onnx:
pip install onnx
Install onnxruntime:
pip install onnxruntime      # inference on CPU
pip install onnxruntime-gpu  # inference on GPU
Note: the installed ONNX and ONNX Runtime versions must be compatible with each other. The compatibility matrix is on the ONNX Runtime GitHub:
https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
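A quick way to check the versions you actually have installed, along with the execution providers your ONNX Runtime build supports (a small sketch using the libraries' standard attributes):
import onnx
import onnxruntime

print(onnx.__version__)                       # installed ONNX version
print(onnxruntime.__version__)                # installed ONNX Runtime version
print(onnxruntime.get_available_providers())  # e.g. ['CPUExecutionProvider', ...]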
Compatibility between ONNX Runtime and CUDA is documented at:
https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html
The matching versions of ONNX Runtime, TensorRT, and CUDA are also listed in the ONNX Runtime documentation.
3. Converting a Model to ONNX Format
- Use torch.onnx.export() to convert the model to ONNX format.
- Before exporting, you must call model.eval() or model.train(False) to make sure the model is in inference mode.
import torch.onnx
import torchvision

# Name of the exported ONNX file; the suffix must be .onnx
onnx_file_name = "resnet50.onnx"
# The model to convert; replace with your own model
model = torchvision.models.resnet50(pretrained=True)
# Load the weights; replace "resnet50.pt" with your own checkpoint.
# Note: load_state_dict returns a report object, not the model,
# so do not assign its result back to `model`
model.load_state_dict(torch.load("resnet50.pt"))
# Before exporting, the model must be put in inference mode
# with model.eval() or model.train(False)
model.eval()
# dummy_input is an example input; it only provides shape, dtype, etc.
batch_size = 1  # an arbitrary value; it matters little once dynamic_axes is set
dummy_input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
# The model output for this example input
output = model(dummy_input)
# Export the model
torch.onnx.export(model,                    # the model being exported
                  dummy_input,              # an example input
                  onnx_file_name,           # where to save the model
                  export_params=True,       # True (the default) also exports the parameters;
                                            # set False to export an untrained model
                  opset_version=10,         # the ONNX operator-set version (15 at the time of writing)
                  do_constant_folding=True, # whether to apply constant-folding optimization
                  input_names=['conv1'],    # name of the model's input tensor
                  output_names=['fc'],      # name of the model's output tensor
                  # dynamic_axes marks the batch dimension as dynamic, so later
                  # inference may use a batch size different from dummy_input's
                  dynamic_axes={'conv1': {0: 'batch_size'},
                                'fc': {0: 'batch_size'}})
Note: the operator-set version reference is at
https://github.com/onnx/onnx/blob/main/docs/Operators.md
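To confirm that dynamic_axes took effect, one quick check (a minimal sketch, assuming resnet50.onnx was exported as above) is to run the model with a batch size different from dummy_input's:
import numpy as np
import onnxruntime

# Feed a batch of 4 even though the model was exported with batch_size=1
sess = onnxruntime.InferenceSession("resnet50.onnx", providers=['CPUExecutionProvider'])
batch = np.random.randn(4, 3, 224, 224).astype(np.float32)
out = sess.run(None, {sess.get_inputs()[0].name: batch})[0]
print(out.shape)  # expected: (4, 1000)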
Checking the ONNX model
We should verify that the exported model file is usable; onnx.checker.check_model() does exactly that.
import onnx

# Use exception handling to run the check
try:
    # check_model raises an exception when the model is invalid
    onnx.checker.check_model(onnx_file_name)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s" % e)
else:
    # No exception means the model is usable
    print("The model is valid!")
Visualizing the ONNX model
Use Netron for visualization. Download: https://netron.app/
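Netron can also be launched from a script (a small sketch, assuming the netron pip package is installed):
import netron

# Starts a local server and opens the graph view in the browser
netron.start("resnet50.onnx")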
(Netron renders the exported graph, including the model's input & output information.)
Inference with ONNX Runtime
import onnxruntime

# The ONNX model file to run inference with
onnx_file_name = "xxxxxx.onnx"
# onnxruntime.InferenceSession creates an ONNX Runtime inference session
ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=['CPUExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['CUDAExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['OpenVINOExecutionProvider'])

# Build the input dictionary; its keys must match the input_names used when exporting.
# input_img must also be converted to an ndarray
# ort_inputs = {'conv1': input_img}
# The form below is preferred because it avoids hard-coding the key:
ort_inputs = {ort_session.get_inputs()[0].name: input_img}

# run() performs the inference; its first argument is a list of output tensor
# names (usually None), its second is the input dictionary.
# The result comes back wrapped in a list, hence the [0] index
ort_output = ort_session.run(None, ort_inputs)[0]
# output_name = ort_session.get_outputs()[0].name
# ort_output = ort_session.run([output_name], ort_inputs)[0]
Notes:
- A PyTorch model takes tensors as input, while ONNX Runtime takes arrays, so we must convert the tensor or read the data as an array in the first place (see the sketch after this list).
- The shape of the input array should match the shape of the dummy_input used at export time; images of a different size should be resized first.
- run() returns a list, so we need to index into it to get the array-format result.
- When building the input dictionary, the keys must match the input_names set when exporting to ONNX.
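A minimal input-preparation sketch tying these notes together (a hedged example: the cat.jpg path is illustrative, it reuses the ort_session created above with the 'conv1' input from the export example, and normalization is omitted for brevity):
import numpy as np
from PIL import Image

# Load an image, resize it to the exported input size, and convert it to a
# float32 NCHW ndarray, since ONNX Runtime expects arrays rather than tensors
img = Image.open("cat.jpg").convert("RGB").resize((224, 224))
input_img = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)  # HWC -> CHW
input_img = input_img[np.newaxis, ...] / 255.0                    # add batch dim, scale to [0, 1]

ort_inputs = {ort_session.get_inputs()[0].name: input_img}
ort_output = ort_session.run(None, ort_inputs)[0]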
Complete Code
1. Install & Download
#!pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
#!pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
#!pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple
# Download ImageNet labels
#!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
2. Define the Model
import torch
import io
import time
from PIL import Image
import torchvision.transforms as transforms
from torchvision import datasets
import onnx
import onnxruntime
import torchvision
import numpy as np
from torch import nn
import torch.nn.init as init
onnx_file = 'resnet50.onnx'
save_dir = './resnet50.pt'
# Download the pretrained model
Resnet50 = torchvision.models.resnet50(pretrained=True)
# Save the model weights
torch.save(Resnet50.state_dict(), save_dir)
print(Resnet50)
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(...)  # 3 Bottleneck blocks, elided
  (layer2): Sequential(...)  # 4 Bottleneck blocks, elided
  (layer3): Sequential(...)  # 6 Bottleneck blocks, elided
  (layer4): Sequential(...)  # 3 Bottleneck blocks, elided
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
3. Export the Model to ONNX Format
batch_size = 1  # just a random number
# First, rebuild the model structure
loaded_model = torchvision.models.resnet50()
# Then load the model weights
loaded_model.load_state_dict(torch.load(save_dir))
# Single-GPU setup
# loaded_model.cuda()
# Put the model in inference mode
loaded_model.eval()
# Input to the model
x = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
torch_out = loaded_model(x)
torch_out
tensor([[-5.8050e-01,  7.5065e-02,  1.9404e-01,  ..., -2.6461e+00,
         -1.1108e+00,  1.1864e+00]], grad_fn=<AddmmBackward0>)
torch_out.size()
torch.Size([1, 1000])
# Export the model
torch.onnx.export(loaded_model,             # model being run
                  x,                        # model input (or a tuple for multiple inputs)
                  onnx_file,                # where to save the model (can be a file or file-like object)
                  export_params=True,       # store the trained parameter weights inside the model file
                  opset_version=10,         # the ONNX version to export the model to
                  do_constant_folding=True, # whether to execute constant folding for optimization
                  input_names=['conv1'],    # the model's input names
                  output_names=['fc'],      # the model's output names
                  # variable-length axes
                  dynamic_axes={'conv1': {0: 'batch_size'},
                                'fc': {0: 'batch_size'}})
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
4. Verify the ONNX Model
# Use exception handling to run the check
try:
    # check_model raises an exception when the model is invalid
    onnx.checker.check_model(onnx_file)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s" % e)
else:
    # No exception means the model is usable
    print("The model is valid!")
The model is valid!
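Beyond the validity check, the exported graph's input and output names can be inspected directly (a short sketch using the standard onnx API, reusing the onnx import above):
model_proto = onnx.load(onnx_file)
# These names should match the input_names/output_names passed to torch.onnx.export
print([i.name for i in model_proto.graph.input])   # e.g. ['conv1']
print([o.name for o in model_proto.graph.output])  # e.g. ['fc']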
5. Inference with ONNX Runtime
import onnxruntime
import numpy as np

ort_session = onnxruntime.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])

# Convert a tensor to an ndarray
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Build the input dictionary and compute the output
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# Compare the PyTorch and ONNX Runtime results numerically
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
Exported model has been tested with ONNXRuntime, and the result looks good!
6. Run a Real Prediction and Visualize the Result
# Inference data
from PIL import Image
from torchvision.transforms import transforms

# Prepare the inference image
image = Image.open('./images/cat.jpg')
# Resize the image to the target size
image = image.resize((224, 224))
# Convert the image to RGB mode
image = image.convert('RGB')
image.save('./images/cat_224.jpg')

categories = []
# Read the categories
with open("./imagenet/imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

def get_class_name(probabilities):
    # Show the top categories for the image
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    for i in range(top5_prob.size(0)):
        print(categories[top5_catid[i]], top5_prob[i].item())

# Preprocessing
def pre_image(image_file):
    input_image = Image.open(image_file)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    inputs = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
    # input_arr = inputs.cpu().detach().numpy()
    return inputs
# Inference with the PyTorch model
# First, rebuild the model structure
resnet50 = torchvision.models.resnet50()
# Then load the model weights
resnet50.load_state_dict(torch.load(save_dir))
resnet50.eval()

# Run inference
input_batch = pre_image('./images/cat_224.jpg')
# Move the input and model to GPU for speed if available
print("GPU Availability: ", torch.cuda.is_available())
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    resnet50.to('cuda')
with torch.no_grad():
    output = resnet50(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
# print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
get_class_name(probabilities)
GPU Availability: False
Persian cat 0.6668420433998108
lynx 0.023987364023923874
bow tie 0.016234245151281357
hair slide 0.013150070793926716
Japanese spaniel 0.012279157526791096
input_batch.size()
torch.Size([1, 3, 224, 224])
# Benchmark performance
latency = []
for i in range(10):
    with torch.no_grad():
        start = time.time()
        output = resnet50(input_batch)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        # for catid in range(top5_catid.size(0)):
        #     print(categories[catid])
        latency.append(time.time() - start)
        print("{} model inference CPU time:cost {} ms".format(str(i), format(sum(latency) * 1000 / len(latency), '.2f')))
0 model inference CPU time:cost 149.59 ms
1 model inference CPU time:cost 130.74 ms
2 model inference CPU time:cost 133.76 ms
3 model inference CPU time:cost 130.64 ms
4 model inference CPU time:cost 131.72 ms
5 model inference CPU time:cost 130.88 ms
6 model inference CPU time:cost 136.31 ms
7 model inference CPU time:cost 139.95 ms
8 model inference CPU time:cost 141.90 ms
9 model inference CPU time:cost 140.96 ms
# Inference with ONNX Runtime
import onnxruntime
from onnx import numpy_helper
import time
onnx_file = 'resnet50.onnx'
session_fp32 = onnxruntime.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['CUDAExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['OpenVINOExecutionProvider'])def softmax(x):"""Compute softmax values for each sets of scores in x."""e_x = np.exp(x - np.max(x))return e_x / e_x.sum()latency = []
def run_sample(session, categories, inputs):start = time.time()input_arr = inputsort_outputs = session.run([], {'conv1':input_arr})[0]output = ort_outputs.flatten()output = softmax(output) # this is optionaltop5_catid = np.argsort(-output)[:5]# for catid in top5_catid:# print(categories[catid])latency.append(time.time() - start)return ort_outputs
input_tensor = pre_image('./images/cat_224.jpg')
input_arr = input_tensor.cpu().detach().numpy()
for i in range(10):ort_output = run_sample(session_fp32, categories, input_arr)print("{} ONNX Runtime CPU Inference time = {} ms".format(str(i),format(sum(latency) * 1000 / len(latency), '.2f')))
0 ONNX Runtime CPU Inference time = 67.66 ms
1 ONNX Runtime CPU Inference time = 56.30 ms
2 ONNX Runtime CPU Inference time = 53.90 ms
3 ONNX Runtime CPU Inference time = 58.18 ms
4 ONNX Runtime CPU Inference time = 64.53 ms
5 ONNX Runtime CPU Inference time = 62.79 ms
6 ONNX Runtime CPU Inference time = 61.75 ms
7 ONNX Runtime CPU Inference time = 60.51 ms
8 ONNX Runtime CPU Inference time = 59.35 ms
9 ONNX Runtime CPU Inference time = 57.57 ms
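ONNX Runtime performance can often be tuned further through SessionOptions; a hedged sketch (the thread count and optimization level below are illustrative values, not tuned ones):
import onnxruntime

sess_options = onnxruntime.SessionOptions()
sess_options.intra_op_num_threads = 4  # illustrative; the right value depends on the machine
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

session_tuned = onnxruntime.InferenceSession('resnet50.onnx', sess_options,
                                             providers=['CPUExecutionProvider'])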
4. Further Topics
- Model quantization (see the sketch below)
- Model pruning
- Engineering optimization
- Operator optimization
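As a small taste of model quantization, ONNX Runtime ships a dynamic-quantization helper; a minimal sketch (assuming the resnet50.onnx exported above, with resnet50_int8.onnx as an illustrative output name):
from onnxruntime.quantization import QuantType, quantize_dynamic

# Weights are quantized to int8 offline; activations are quantized dynamically at run time
quantize_dynamic(
    model_input='resnet50.onnx',
    model_output='resnet50_int8.onnx',
    weight_type=QuantType.QInt8,
)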