
Model Deployment & Inference

  • Model deployment
  • Model inference

We will convert a model trained in PyTorch to the ONNX format, then run it with ONNX Runtime for inference.

1. ONNX

ONNX (Open Neural Network Exchange) is a standard format for describing computation graphs, released jointly by Facebook (now Meta) and Microsoft in 2017. By defining a standard format independent of environment and platform, ONNX lets AI models be exchanged across different frameworks and environments. ONNX can be seen as a bridge between deep learning frameworks and the deployment side, much like an intermediate language in a compiler.

Because framework compatibility varies, ONNX is typically used only to represent static graphs, which are easier to deploy. Hardware and software vendors only need to optimize model performance against the ONNX standard, and every framework compatible with ONNX benefits.

ONNX focuses mainly on model prediction. Models trained with different frameworks, once converted to the ONNX format, can easily be deployed in any ONNX-compatible runtime environment.

  • ONNX website: https://onnx.ai/
  • ONNX GitHub: https://github.com/onnx/onnx
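
Once a model is in ONNX form it is just a protobuf file that any compatible tool can read. As a minimal sketch (assuming a file such as the resnet50.onnx exported later in this post), the onnx Python package can load a model and print its graph:

import onnx

# Load the exported model; an ONNX file is a serialized ModelProto protobuf
model = onnx.load("resnet50.onnx")

# Print a human-readable summary of the computation graph
print(onnx.helper.printable_graph(model.graph))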


2. ONNX Runtime

  • ONNX Runtime website: https://www.onnxruntime.ai/
  • ONNX Runtime GitHub: https://github.com/microsoft/onnxruntime

ONNX Runtime is a cross-platform machine learning inference accelerator maintained by Microsoft. It works with ONNX directly: it reads a .onnx file and runs inference on it, with no need to convert the file into any other format first.

With ONNX Runtime, PyTorch covers the last mile of deployment, giving the full pipeline PyTorch --> ONNX --> ONNX Runtime.

Install onnx:

pip install onnx

Install onnxruntime (pick one; the CPU and GPU packages should not be installed in the same environment):

pip install onnxruntime      # inference on CPU

pip install onnxruntime-gpu  # inference on GPU

Note the version compatibility between ONNX and ONNX Runtime, which can be checked on the ONNX Runtime GitHub:

URL: https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md
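
To see which versions are installed locally (for checking against that table):

import onnx
import onnxruntime

# Print the installed versions so they can be compared against the
# compatibility tables in the ONNX Runtime docs
print("onnx version:", onnx.__version__)
print("onnxruntime version:", onnxruntime.__version__)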


Compatibility between ONNX Runtime and CUDA:

URL: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html
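
To confirm at runtime whether the GPU build is actually usable, onnxruntime can report its available execution providers; a quick check:

import onnxruntime

# 'CUDAExecutionProvider' only appears if onnxruntime-gpu is installed
# together with a compatible CUDA/cuDNN stack
print(onnxruntime.get_available_providers())
print(onnxruntime.get_device())  # 'CPU' or 'GPU'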


The compatibility matrix covering ONNX Runtime, TensorRT, and CUDA is documented there as well.


3. Converting a Model to ONNX Format

  • Use torch.onnx.export() to convert the model to the ONNX format
  • Before exporting, you must call model.eval() or model.train(False) to put the model in inference mode
import torch.onnx
import torchvision

# Name for the exported ONNX model; the file suffix must be .onnx
onnx_file_name = "resnet50.onnx"
# The model we want to convert; replace with your own model
model = torchvision.models.resnet50(pretrained=True)
# Load the weights; replace "resnet50.pt" with your own checkpoint
model.load_state_dict(torch.load("resnet50.pt"))
# Before exporting, you must call model.eval() or model.train(False)
model.eval()
# dummy_input is an example input that only supplies shape/dtype information
batch_size = 1  # an arbitrary value; with dynamic_axes set below it matters little
dummy_input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
# The model output for this example input
output = model(dummy_input)
# Export the model
torch.onnx.export(model,                     # the model to export
                  dummy_input,               # an example input
                  onnx_file_name,            # output path / file name
                  export_params=True,        # True (the default) also exports the parameters;
                                             # set False to export an untrained model
                  opset_version=10,          # ONNX operator-set version (newer sets exist)
                  do_constant_folding=True,  # whether to apply constant-folding optimization
                  input_names=['conv1'],     # names for the model's input tensors
                  output_names=['fc'],       # names for the model's output tensors
                  # dynamic_axes marks the batch dimension as dynamic, so later
                  # inference inputs need not match dummy_input's batch_size
                  dynamic_axes={'conv1': {0: 'batch_size'},
                                'fc': {0: 'batch_size'}})

Note: the operator-set version reference is at https://github.com/onnx/onnx/blob/main/docs/Operators.md
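
To double-check which operator set the exported file actually declares (it should match the opset_version passed to torch.onnx.export), a small sketch:

import onnx

model = onnx.load("resnet50.onnx")
# Each opset_import entry carries an operator-set domain and version
print([(imp.domain, imp.version) for imp in model.opset_import])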

Validating the ONNX Model

We should verify that the exported model file is usable, which we do with onnx.checker.check_model().

import onnx
# We can use exception handling to perform the check
try:
    # check_model raises an exception when the model is invalid;
    # it accepts either a loaded ModelProto or a path to a .onnx file
    onnx.checker.check_model(onnx_file_name)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s" % e)
else:
    # A valid model raises no exception, so we reach this branch
    print("The model is valid!")

Visualizing the ONNX Model

Use Netron for visualization. Download: https://netron.app/
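
Netron is also available as a Python package; a minimal sketch, assuming it has been installed with pip install netron:

import netron

# Serve an interactive view of the model in the browser
netron.start("resnet50.onnx")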

Netron renders the exported graph and also shows the model's input & output information (names, shapes, and dtypes); the screenshots are omitted here.

Inference with ONNX Runtime


import onnxruntime

# Name of the ONNX model file to run inference on
onnx_file_name = "xxxxxx.onnx"
# onnxruntime.InferenceSession creates an ONNX Runtime inference session
ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=['CPUExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['CUDAExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['OpenVINOExecutionProvider'])

# Build the input dictionary; the keys must match the input_names used
# when the ONNX model was exported, and input_img must be an ndarray
# ort_inputs = {'conv1': input_img}
# The form below is preferred, since it avoids hard-coding the key
ort_inputs = {ort_session.get_inputs()[0].name: input_img}

# run() performs inference: the first argument is a list of output tensor
# names (usually None to get all outputs), the second is the input dict.
# The result comes back wrapped in a list, hence the [0] index
ort_output = ort_session.run(None, ort_inputs)[0]
# output_name = ort_session.get_outputs()[0].name
# ort_output = ort_session.run([output_name], ort_inputs)[0]

Note:

  • A PyTorch model takes tensors as input, while ONNX Runtime takes NumPy arrays, so we must either convert the tensors or read the data directly in array form
  • The input array's shape must match the dummy_input used at export time; if the image size differs, resize it first (see the sketch after this list)
  • run() returns a list, so an index is needed to get the array result
  • When building the input dictionary, the keys must match the input_names set during ONNX export
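
A minimal preprocessing sketch covering these points; the file name cat.jpg is illustrative, and the full ImageNet mean/std normalization (shown in the complete code below) is omitted for brevity:

import numpy as np
from PIL import Image

# Read the image straight into an ndarray shaped like dummy_input:
# (batch, channels, height, width) = (1, 3, 224, 224), float32
img = Image.open("cat.jpg").convert("RGB").resize((224, 224))
input_img = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)[None] / 255.0
print(input_img.shape)  # (1, 3, 224, 224)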

Complete Code

1. Installation & Downloads

#!pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
#!pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
#!pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple
# Download ImageNet labels
#!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

2. Define the Model

import torch
import io
import time
from PIL import Image
import torchvision.transforms as transforms
from torchvision import datasets
import onnx
import onnxruntime
import torchvision
import numpy as np
from torch import nn
import torch.nn.init as init
onnx_file = 'resnet50.onnx'
save_dir = './resnet50.pt'

# Download the pretrained model
Resnet50 = torchvision.models.resnet50(pretrained=True)
# Save the model weights
torch.save(Resnet50.state_dict(), save_dir)
print(Resnet50)
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(...)   # 3 Bottleneck blocks, elided
  (layer2): Sequential(...)   # 4 Bottleneck blocks, elided
  (layer3): Sequential(...)   # 6 Bottleneck blocks, elided
  (layer4): Sequential(...)   # 3 Bottleneck blocks, elided
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

3. Export the Model to ONNX Format


batch_size = 1    # just a random number
# First load the model structure
loaded_model = torchvision.models.resnet50()
# Then load the model weights
loaded_model.load_state_dict(torch.load(save_dir))
# Single-GPU:
# loaded_model.cuda()

# Put the model in inference mode
loaded_model.eval()
# Input to the model
x = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
torch_out = loaded_model(x)
torch_out
tensor([[-5.8050e-01,  7.5065e-02,  1.9404e-01, -9.1107e-01,  9.9716e-01,
         -1.2941e+00, -1.3402e-01, -6.4496e-01,  6.0434e-01, -1.6355e+00,
         ...
         -3.7416e+00, -2.0244e+00, -2.6461e+00, -1.1108e+00,  1.1864e+00]],
       grad_fn=<AddmmBackward0>)
torch_out.size()
torch.Size([1, 1000])

# Export the model
torch.onnx.export(loaded_model,              # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  onnx_file,                 # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['conv1'],     # the model's input names
                  output_names=['fc'],       # the model's output names
                  # variable-length axes
                  dynamic_axes={'conv1': {0: 'batch_size'},
                                'fc': {0: 'batch_size'}})
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

4. Validate the ONNX Model

# We can use exception handling to perform the check
try:
    # check_model raises an exception when the model is invalid
    onnx.checker.check_model(onnx_file)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s" % e)
else:
    # A valid model raises no exception
    print("The model is valid!")
The model is valid!

5. Inference with ONNX Runtime

import onnxruntime
import numpy as np

ort_session = onnxruntime.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])

# Convert a tensor to ndarray format
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Build the input dictionary and compute the output
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# Compare the numerical results from PyTorch and ONNX Runtime
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
Exported model has been tested with ONNXRuntime, and the result looks good!

6. Run a Real Prediction and Visualize the Result

# Inference data
from PIL import Image
from torchvision.transforms import transforms

# Prepare the inference image
image = Image.open('./images/cat.jpg')
# Resize the image to the target size
image = image.resize((224, 224))
# Convert the image to RGB mode
image = image.convert('RGB')
image.save('./images/cat_224.jpg')

categories = []
# Read the categories
with open("./imagenet/imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

def get_class_name(probabilities):
    # Show the top categories for the image
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    for i in range(top5_prob.size(0)):
        print(categories[top5_catid[i]], top5_prob[i].item())

# Preprocessing
def pre_image(image_file):
    input_image = Image.open(image_file)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    inputs = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
    # input_arr = inputs.cpu().detach().numpy()
    return inputs

# Inference with the model
# First load the model structure
resnet50 = torchvision.models.resnet50()
# Then load the model weights
resnet50.load_state_dict(torch.load(save_dir))
resnet50.eval()

# Run inference
input_batch = pre_image('./images/cat_224.jpg')
# Move the input and model to GPU for speed if available
print("GPU Availability: ", torch.cuda.is_available())
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    resnet50.to('cuda')
with torch.no_grad():
    output = resnet50(input_batch)
# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
# print(output[0])
# The output has unnormalized scores. To get probabilities, run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
get_class_name(probabilities)
GPU Availability:  False
Persian cat 0.6668420433998108
lynx 0.023987364023923874
bow tie 0.016234245151281357
hair slide 0.013150070793926716
Japanese spaniel 0.012279157526791096
input_batch.size()
torch.Size([1, 3, 224, 224])
# Benchmark performance
latency = []
for i in range(10):
    with torch.no_grad():
        start = time.time()
        output = resnet50(input_batch)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        # for catid in range(top5_catid.size(0)):
        #     print(categories[catid])
        latency.append(time.time() - start)
        print("{} model inference CPU time:cost {} ms".format(str(i), format(sum(latency) * 1000 / len(latency), '.2f')))
0 model inference CPU time:cost 149.59 ms
1 model inference CPU time:cost 130.74 ms
2 model inference CPU time:cost 133.76 ms
3 model inference CPU time:cost 130.64 ms
4 model inference CPU time:cost 131.72 ms
5 model inference CPU time:cost 130.88 ms
6 model inference CPU time:cost 136.31 ms
7 model inference CPU time:cost 139.95 ms
8 model inference CPU time:cost 141.90 ms
9 model inference CPU time:cost 140.96 ms
# Inference with ONNX Runtime
import onnxruntime
from onnx import numpy_helper
import time

onnx_file = 'resnet50.onnx'
session_fp32 = onnxruntime.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['CUDAExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['OpenVINOExecutionProvider'])

def softmax(x):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

latency = []
def run_sample(session, categories, inputs):
    start = time.time()
    input_arr = inputs
    # The input key 'conv1' matches the input_names used at export time
    ort_outputs = session.run([], {'conv1': input_arr})[0]
    output = ort_outputs.flatten()
    output = softmax(output)  # this is optional
    top5_catid = np.argsort(-output)[:5]
    # for catid in top5_catid:
    #     print(categories[catid])
    latency.append(time.time() - start)
    return ort_outputs

input_tensor = pre_image('./images/cat_224.jpg')
input_arr = input_tensor.cpu().detach().numpy()
for i in range(10):
    ort_output = run_sample(session_fp32, categories, input_arr)
    print("{} ONNX Runtime CPU Inference time = {} ms".format(str(i), format(sum(latency) * 1000 / len(latency), '.2f')))
0 ONNX Runtime CPU Inference time = 67.66 ms
1 ONNX Runtime CPU Inference time = 56.30 ms
2 ONNX Runtime CPU Inference time = 53.90 ms
3 ONNX Runtime CPU Inference time = 58.18 ms
4 ONNX Runtime CPU Inference time = 64.53 ms
5 ONNX Runtime CPU Inference time = 62.79 ms
6 ONNX Runtime CPU Inference time = 61.75 ms
7 ONNX Runtime CPU Inference time = 60.51 ms
8 ONNX Runtime CPU Inference time = 59.35 ms
9 ONNX Runtime CPU Inference time = 57.57 ms
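
In this run, ONNX Runtime averages roughly 55-65 ms per image against roughly 130-150 ms for eager PyTorch on the same CPU, a bit over a 2x speedup; exact numbers will of course vary with hardware and build.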

4. Further Topics

  • Model quantization (a minimal sketch follows this list)
  • Model pruning
  • Engineering optimizations
  • Operator optimizations
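
ONNX Runtime ships a quantization toolkit; as one example, a minimal dynamic-quantization sketch, assuming the resnet50.onnx exported above (for convolutional models, static quantization with calibration data usually gives better results):

from onnxruntime.quantization import quantize_dynamic, QuantType

# Quantize the model weights to int8; activations stay in float and
# are quantized on the fly at inference time
quantize_dynamic(
    model_input="resnet50.onnx",
    model_output="resnet50_int8.onnx",
    weight_type=QuantType.QInt8,
)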
