網(wǎng)站沒有問題但是一直做不上首頁seo托管
目錄
寫在開頭
一、CNN的原理
1. 概述
2. 卷積層
內(nèi)參數(shù)(卷積核本身)
外參數(shù)(填充和步幅)
輸入與輸出的尺寸關(guān)系?
3. 多通道問題?
多通道輸入
多通道輸出
4. 池化層
平均匯聚
最大值匯聚
二、手寫數(shù)字識別
1. 任務(wù)描述和數(shù)據(jù)集加載
2. 網(wǎng)絡(luò)結(jié)構(gòu)(LeNet-5)
3. 模型訓(xùn)練
4. 模型測試
5. 直觀顯示預(yù)測結(jié)果
寫在最后
寫在開頭
? ? 本文將將介紹如何使用PyTorch框架搭建卷積神經(jīng)網(wǎng)絡(luò)(CNN)模型。簡易介紹卷積神經(jīng)網(wǎng)絡(luò)的原理,并實現(xiàn)模型的搭建、數(shù)據(jù)集加載、模型訓(xùn)練、測試、網(wǎng)絡(luò)的保存等。實現(xiàn)機器學(xué)習(xí)領(lǐng)域的Hello world——手寫數(shù)字識別。本文的講解參考了B站兩位up主:爆肝杰哥、炮哥帶你學(xué)。有關(guān)Pytorch環(huán)境配置和CNN具體原理大家可以自行查閱資料,本文多數(shù)圖片也來自于爆肝杰哥的講解。這里也放上兩位up主的視頻鏈接:
從0開始擼代碼--手把手教你搭建LeNet-5網(wǎng)絡(luò)模型_嗶哩嗶哩_bilibili
?Python深度學(xué)習(xí):安裝Anaconda、PyTorch(GPU版)庫與PyCharm_嗶哩嗶哩_bilibili
? ? 本文使用的PyTorch為1.12.0版本,Numpy為1.21版本,相近的版本語法差異很小。有關(guān)數(shù)組的數(shù)據(jù)結(jié)構(gòu)教程、神經(jīng)網(wǎng)絡(luò)的基本原理(前向傳播/反向傳播)、神經(jīng)網(wǎng)絡(luò)作為“函數(shù)模擬器”直觀感受、深度神經(jīng)網(wǎng)絡(luò)的實現(xiàn)DNN詳見本專欄的前三篇文章,鏈接如下:?
【深度學(xué)習(xí)基礎(chǔ)】NumPy數(shù)組庫的使用-CSDN博客
【深度學(xué)習(xí)基礎(chǔ)】用PyTorch從零開始搭建DNN深度神經(jīng)網(wǎng)絡(luò)_如何搭建一個深度學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)dnn pytorch-CSDN博客
【深度學(xué)習(xí)基礎(chǔ)】使用Pytorch搭建DNN深度神經(jīng)網(wǎng)絡(luò)與手寫數(shù)字識別_dnn網(wǎng)絡(luò)模型 代碼-CSDN博客
? ? 基于深度神經(jīng)網(wǎng)絡(luò)DNN實現(xiàn)的手寫數(shù)字識別,將灰度圖像轉(zhuǎn)換后的二維數(shù)組展平到一維,將一維的784個特征作為模型輸入。在“展平”的過程中必然會失去一些圖像的形狀結(jié)構(gòu)特征,因此基于DNN的實現(xiàn)方式并不能很好的利用圖像的二維結(jié)構(gòu)特征,而卷積神經(jīng)網(wǎng)絡(luò)CNN對于處理圖像的位置信息具有一定的優(yōu)勢。因此卷積神經(jīng)網(wǎng)絡(luò)經(jīng)常被用于圖像識別/處理領(lǐng)域。下面我們將對CNN進行具體介紹。
一、CNN的原理
1. 概述
? ? 在上一篇博客介紹的深度神經(jīng)網(wǎng)絡(luò)DNN中,網(wǎng)絡(luò)的每一層神經(jīng)元相互直接都有鏈接,每一層都是全連接層,我們的目標(biāo)就是訓(xùn)練這個全連接層的權(quán)重w和偏執(zhí)b,最終得到預(yù)測效果良好的網(wǎng)絡(luò)結(jié)構(gòu)。
? ? DNN的全連接層對應(yīng)于CNN中的卷積層,而池化層(匯聚)其實與激活函數(shù)的作用類似。CNN中完整的卷積層的結(jié)構(gòu)是:卷積-激活函數(shù)-池化(匯聚),其中池化層也有時可以省略。一個卷積神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)如下:
? ? 如上圖所示,CNN的優(yōu)勢在于可以處理多為輸入數(shù)據(jù),并同樣以多維數(shù)據(jù)的形式輸出至下一層,保留了更多的空間信息特征。而DNN卻只能將多維數(shù)據(jù)展平成一維數(shù)據(jù),必然會損失一些空間特征。
2. 卷積層
內(nèi)參數(shù)(卷積核本身)
? ? CNN中的卷積層和DNN中的全連接層是平級關(guān)系,在DNN中,我們訓(xùn)練的內(nèi)參數(shù)是全連接層的權(quán)重w和偏置b,CNN也類似,CNN訓(xùn)練的是卷積核,也就相當(dāng)于包含了權(quán)重和偏置兩個內(nèi)部參數(shù)。下面我們首先描述什么是卷積運算。當(dāng)輸入數(shù)據(jù)進入卷積層后,輸入數(shù)據(jù)會與卷積核進行卷積運算,運算方法如下圖所示:
? ? ?輸入一個多維數(shù)據(jù)(上圖為二維),與卷積核進行運算,即輸入中與卷積核形狀相同的部分,分別與卷積核進行逐個元素相乘再相加。例如計算結(jié)果中坐上角的15是根據(jù)如下過程計算得到的:
逐個元素相乘再相加,即:
1 * 2 + 2 * 0 + 3* 1 + 0 * 0 + 1 * 1 + 2 * 2 + 3 * 1 + 0 * 0 + 1 * 2 = 15?
? ? ?卷積核本身相當(dāng)于權(quán)重,再卷積運算的過程中也可以存在偏置,如下:
? ? 卷積核(即CNN的權(quán)重和偏置)本身為內(nèi)參數(shù),(具體里面的數(shù)字)是我們通過訓(xùn)練得出的,我們寫代碼的時候只要關(guān)注一些外部設(shè)定的參數(shù)即可。下面我們將介紹一些外參數(shù)。
外參數(shù)(填充和步幅)
? ?填充(padding)
? ?顯然,只要卷積核的大小>1*1,必然會導(dǎo)致圖像越卷越小,為了防止輸入經(jīng)過多個卷積層后變得過小,可以在發(fā)生卷積層之前,先向輸入圖像的外圍填入固定的數(shù)據(jù)(比如0),這個步驟稱之為填充,如下圖:
? ? ?在我們使用Pytorch搭建卷積層的時候,需要在對應(yīng)的接口中添加這個padding參數(shù),向上圖中這種情況,相當(dāng)于在3*3的卷積核外圍添加了“一圈”,則padding = 1,卷積層的接口中就要這樣寫:
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, paddding=1)
? ? ??參數(shù)in_channels和out_channels是對應(yīng)于這個卷積層輸入和輸出的通道數(shù)參數(shù),這里我們先放一放。
? ?步幅(stride)
? ?步幅指的是使用卷積核的位置間隔,即輸入中參與運算的那個范圍每次移動的距離。前面幾個示意圖中的步幅均為1,即每次移動一格,如果設(shè)置stride=2,kernel_size=2,則效果如下:
? ? ?此時需要在卷積層接口中添加參數(shù)stride=2。
輸入與輸出的尺寸關(guān)系?
? ?綜上所述,結(jié)合外參數(shù)(步幅、填充)和內(nèi)參數(shù)(卷積核),可以看出如下規(guī)律:
卷積核越大,輸出越小。
步幅越大,輸出越小。
填充越大,輸出越大。
? ? ?用公式表示定量關(guān)系:
? ? ? 如果輸入和卷積核均為方陣,設(shè)輸入尺寸為W*W,輸出尺寸為N*N,卷積核尺寸為F*F,填充的圈數(shù)為P,步幅為S,則有關(guān)系:
? ? 這個關(guān)系大家要重點掌握,也可以自己推導(dǎo)一下,并不復(fù)雜。如果輸入和卷積核不為方陣,設(shè)輸入尺寸是H*W,輸出尺寸是OH*OW,卷積核尺寸為FH*FW,填充為P,步幅為S,則輸出尺寸OH*OW的計算公式是:
?
3. 多通道問題?
多通道輸入
? ? 對于手寫數(shù)字識別這種灰度圖像,可以視為僅有(高*長)二維的輸入。然而,對于彩色圖像,每一個像素點都相當(dāng)于是RGB的三個值的組合,因此對于彩色的圖像輸入,除了高*長兩個維度外,還有第三個維度——通道,即紅、綠、藍(lán)三個通道,也可以視為3個單通道的二維圖像的混合疊加。
當(dāng)輸入數(shù)據(jù)僅為二維時,卷積層的權(quán)重往往被稱作卷積核(Kernel);
當(dāng)輸入數(shù)據(jù)為三維或更高時,卷積層的權(quán)重往往被稱作濾波器(Filter)。
? ? ?對于多通道輸入,輸入數(shù)據(jù)和濾波器的通道數(shù)必須保持一致。這樣會導(dǎo)致輸出結(jié)果降維成二維,如下圖:
? ? ?對形狀進行一下抽象,則輸入數(shù)據(jù)C*H*W和濾波器C*FH*FW都是長方體,結(jié)果是一個長方形1*OH*OW,注意C,H,W是固定的順序,通道數(shù)要寫在最前。
多通道輸出
?? ?如果要實現(xiàn)多通道輸出,那么就需要多個濾波器,讓三維輸入與多個濾波器進行卷積,就可以實現(xiàn)多通道輸出,輸出的通道數(shù)FN就是濾波器的個數(shù)FN,如下圖:
? ? 和單通道一樣,卷積運算后也有偏置,如果進一步追加偏置,則結(jié)果如下:每個通道都有一個單獨的偏置。?
4. 池化層
? ?池化,也叫匯聚(Pooling)。池化層通常位于卷積層之后(有時也可以不設(shè)置池化層),其作用僅僅是在一定范圍內(nèi)提取特征值,所以并不存在要學(xué)習(xí)的內(nèi)部參數(shù)。池化僅僅對圖像的高H和寬W進行特征提取,并不改變通道數(shù)C。
平均匯聚
一般有平均匯聚和最大值匯聚兩種。平均匯聚如下:
? ? 如上圖,池化的窗口大小為2*2,對應(yīng)的步幅為2,因此對于上圖這種情況,對應(yīng)的Pytorch接口如下:
nn.AvgPool2d(kernel_size=2, stride=2)?
最大值匯聚
? ?同理,如果使用最大值匯聚,如下圖所示:
? ? 此處Pytorch函數(shù)就這么寫:
nn.MaxPool2d(kernel_size=2, stride=2)?
二、手寫數(shù)字識別
1. 任務(wù)描述和數(shù)據(jù)集加載
? ? 此處和上一篇博客類似,詳情見:
【深度學(xué)習(xí)基礎(chǔ)】使用Pytorch搭建DNN深度神經(jīng)網(wǎng)絡(luò)與手寫數(shù)字識別_dnn網(wǎng)絡(luò)模型 代碼-CSDN博客
? ? ?接下來我們實現(xiàn)機器學(xué)習(xí)領(lǐng)域的Hello World——手寫數(shù)字識別。使用的數(shù)據(jù)集MNIST是機器學(xué)習(xí)領(lǐng)域的標(biāo)準(zhǔn)數(shù)據(jù)集,其中的每一個樣本都是一副二維的灰度圖像,尺寸為28*28:
? ?輸入就相當(dāng)于一個單通道的圖像,是二維的。我們在實現(xiàn)的時候,要將每個樣本圖像轉(zhuǎn)換為28*28的張量,作為輸入,此處和上一篇DNN都一致。數(shù)據(jù)集則通過包torchvision中的datasets庫進行下載。這里我快速給一段代碼好了,詳情可見上一篇博客。
import torch
from torchvision import datasets, transforms# 設(shè)定下載參數(shù) (數(shù)據(jù)集轉(zhuǎn)換參數(shù)),將圖像轉(zhuǎn)換為張量
data_transform = transforms.Compose([transforms.ToTensor()
])# 加載訓(xùn)練數(shù)據(jù)集
train_dataset = datasets.MNIST(root='D:\\Jupyter\\dataset\\minst', # 下載路徑,讀者請自行設(shè)置train=True, # 是訓(xùn)練集download=True, # 如果該路徑?jīng)]有該數(shù)據(jù)集,則進行下載transform=data_transform # 數(shù)據(jù)集轉(zhuǎn)換參數(shù)
)# 批次加載器,在接下來的訓(xùn)練中進行小批次(16批次)的載入數(shù)據(jù),有助于提高準(zhǔn)確度,對訓(xùn)練集的樣本進行打亂,
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)# 加載測試數(shù)據(jù)集
test_dataset = datasets.MNIST(root='D:\\Jupyter\\dataset\\minst', # 下載路徑train=False, # 是訓(xùn)練集download=True, # 如果該路徑?jīng)]有該數(shù)據(jù)集,則進行下載transform=data_transform # 數(shù)據(jù)集轉(zhuǎn)換參數(shù)
)test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)
2. 網(wǎng)絡(luò)結(jié)構(gòu)(LeNet-5)
? ?本文搭建的LeNet-5起源于1998年,在手寫數(shù)字識別上非常成功。其結(jié)構(gòu)如下:
再列一個表格,具體結(jié)構(gòu)如下:
? ? ?注:輸出層的激活函數(shù)目前已經(jīng)被Softmax取代。
? ? ?至于這些尺寸關(guān)系,我舉個兩例子吧:
? ?以第一層C1的輸入和輸出為例。輸入尺寸W是28*28,卷積核F尺寸為5*5,步幅S為1,填充P為2,那么輸出N的28*28怎么來的呢?按照公式如下:
? ?我們也可以觀察到第一層的卷積核個數(shù)為6,則輸出的通道數(shù)也為6。
? 再看一下第一個池化層S2,輸入尺寸是28*28,卷積核F大小為2*2(此處的“卷積核”實際上指的是采樣范圍),步幅S=2,填充P為0,則輸出的14*14是這么算出來的:
? 其他的沒啥好說的,讀者們可以自行計算這個尺寸關(guān)系。接下來我們給出完整的CNN網(wǎng)絡(luò)代碼,net.py如下:
import torch
from torch import nn# 定義網(wǎng)絡(luò)模型
class MyLeNet5(nn.Module):# 初始化網(wǎng)絡(luò)def __init__(self):super(MyLeNet5, self).__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5), nn.Tanh(),nn.Flatten(),nn.Linear(120, 84), nn.Tanh(),nn.Linear(84, 10))# 前向傳播def forward(self, x):y = self.net(x)return y# 以下為測試代碼,也可不添加
if __name__ == '__main__': x1 = torch.rand([1, 1, 28, 28])model = MyLeNet5()y1 = model(x1)print(x1)print(y1)
? ? ?這里我還在'__main__'添加了一些測試代碼,正如表格中所示,假設(shè)我們向網(wǎng)絡(luò)中輸入一個1*1*28*28的向量,模擬批次大小為1,一個單通道28*28的灰度圖輸入。最終y應(yīng)該是由10個數(shù)字組成的張量,結(jié)果如下:
? ? ?從輸出我們可以直觀的看到輸入經(jīng)過神經(jīng)網(wǎng)絡(luò)前向傳播后的結(jié)果。另外特別注意這個網(wǎng)絡(luò)的結(jié)構(gòu)中的參數(shù),其中卷積層的搭建API有5個外參數(shù):
in_channels:輸入通道數(shù)
out_channels:輸出通道數(shù)
kernel_size: 卷積核尺寸
padding: 填充,不寫則默認(rèn)0
stride: 步幅,不寫則默認(rèn)1
? ? ?這個LeNet-5網(wǎng)絡(luò)結(jié)構(gòu)就長這樣,我們一定要嚴(yán)格遵守,否則有可能出現(xiàn)無論怎么訓(xùn)練,都始終欠擬合的情況。我就曾經(jīng)試過更改/添加不同的激活函數(shù),結(jié)果無論是訓(xùn)練還是測試,準(zhǔn)確率都在10%徘徊,相當(dāng)于隨機瞎猜的效果,因此大家一定要嚴(yán)格遵循這個網(wǎng)絡(luò)結(jié)構(gòu)。?
3. 模型訓(xùn)練
? ? 網(wǎng)絡(luò)搭建好之后,所有的內(nèi)參數(shù)(即卷積核)都是隨機的,下面我們要通過訓(xùn)練盡可能提高網(wǎng)絡(luò)的預(yù)測能力。在訓(xùn)練前,我們首先要選擇損失函數(shù)(這里使用交叉熵?fù)p失函數(shù)),定義優(yōu)化器、進行學(xué)習(xí)率調(diào)整等,代碼如下:
import torch
from torch import nn
from net import MyLeNet5
from torch.optim import lr_scheduler# 判斷是否有g(shù)pu
device = "cuda" if torch.cuda.is_available() else "cpu"# 調(diào)用net,將模型數(shù)據(jù)轉(zhuǎn)移到gpu
model = MyLeNet5().to(device)# 選擇損失函數(shù)
loss_fn = nn.CrossEntropyLoss() # 交叉熵?fù)p失函數(shù),自帶Softmax激活函數(shù)# 定義優(yōu)化器
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)# 學(xué)習(xí)率每隔10輪次, 變?yōu)樵瓉淼?.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
? ? 然后我們可以寫一個用于訓(xùn)練網(wǎng)絡(luò)的函數(shù),四個參數(shù)分別是批次加載器、模型、損失函數(shù)、優(yōu)化器,代碼如下:
# 定義模型訓(xùn)練的函數(shù)
def train(dataloader, model, loss_fn, optimizer):loss, current, n = 0.0, 0.0, 0for batch, (X, y) in enumerate(dataloader):# 前向傳播X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)# 用_和pred分別接收輸出10個元素中的最大值和對應(yīng)下標(biāo)位置_, pred = torch.max(output, dim=1)# 計算當(dāng)前輪次時,訓(xùn)練集的精確度,將所有標(biāo)簽值與預(yù)測值(即下標(biāo)位置)cur_acc = torch.sum(y == pred)/output.shape[0]# 反向傳播,對內(nèi)部參數(shù)(卷積核)進行優(yōu)化optimizer.zero_grad()cur_loss.backward()optimizer.step()# 計算準(zhǔn)確率和損失,這里只是為了實時顯示訓(xùn)練集的擬合情況。也可以不寫loss += cur_loss.item()current += cur_acc.item()n = n + 1print("train_loss: ", str(loss/n))print("train_acc: ", str(current/n))
只要調(diào)用這個函數(shù),即可實現(xiàn)模型訓(xùn)練:
train(train_dataloader, model, loss_fn, optimizer)
當(dāng)然,我們最好是設(shè)定一個輪次epoch,我們后續(xù)會寫這樣一個循環(huán),每訓(xùn)練一個epoch,就進行一次測試,實時顯示一定輪次后訓(xùn)練集和測試集的擬合情況。
4. 模型測試
? ?這里和模型訓(xùn)練類似,只不過我們要觀察訓(xùn)練好的模型,在測試集的預(yù)測效果。與訓(xùn)練的代碼相似,只是沒有了反向傳播優(yōu)化參數(shù)的過程。用于測試的函數(shù)代碼如下:
def test(dataloader, model, loss_fn):model.eval()loss, current, n = 0.0, 0.0, 0# 該局部關(guān)閉梯度計算功能,提高運算效率with torch.no_grad(): for batch, (X, y) in enumerate(dataloader):# 前向傳播X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)_, pred = torch.max(output, dim=1)# 計算當(dāng)前輪次時,訓(xùn)練集的精確度cur_acc = torch.sum(y == pred) / output.shape[0]loss += cur_loss.item()current += cur_acc.item()n = n + 1print("test_loss: ", str(loss / n))print("test_acc: ", str(current / n))return current/n # 返回精確度
? ? 如上代碼,將測試集的精確度作為返回值,我們在外圍調(diào)用這個函數(shù)時,可以通過循環(huán)找到測試集最大的精確度。
? ? 最終我們設(shè)定一個訓(xùn)練輪次epochs,此處epochs=50,每經(jīng)過一個epoch的訓(xùn)練,就進行測試,實時打印觀察訓(xùn)練集和測試集的擬合情況。當(dāng)測試集的精確度是當(dāng)前的最大值時,我們就保存這個模型的參數(shù)到save_model/best_model.pth,代碼如下:
import os# 開始訓(xùn)練
epoch = 50
max_acc = 0
for t in range(epoch):print(f"epoch{t+1}\n---------------")# 訓(xùn)練模型train(train_dataloader, model, loss_fn, optimizer)# 測試模型a = test(test_dataloader, model, loss_fn)# 保存最好的模型參數(shù)if a > max_acc:folder = 'save_model'if not os.path.exists(folder):os.mkdir(folder)max_acc = aprint("current best model acc = ", a)torch.save(model.state_dict(), 'save_model/best_model.pth')
print("Done!")
? ? 運行之后可以發(fā)現(xiàn),測試集的精度經(jīng)過1個epochs就達(dá)到了90%以上,最終經(jīng)過50輪次的訓(xùn)練,測試集精度達(dá)到了99%左右:
? ?模型參數(shù)也得以保存:
? ? 最后給出用于訓(xùn)練和測試的完整代碼train.py,如下所示:
import torch
from torch import nn
from net import MyLeNet5
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import os# 將圖像轉(zhuǎn)換為張量形式
data_transform = transforms.Compose([transforms.ToTensor()
])# 加載訓(xùn)練數(shù)據(jù)集
train_dataset = datasets.MNIST(root='D:\\Jupyter\\dataset\\minst', # 下載路徑train=True, # 是訓(xùn)練集download=True, # 如果該路徑?jīng)]有該數(shù)據(jù)集,則進行下載transform=data_transform # 數(shù)據(jù)集轉(zhuǎn)換參數(shù)
)# 批次加載器
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)# 加載測試數(shù)據(jù)集
test_dataset = datasets.MNIST(root='D:\\Jupyter\\dataset\\minst', # 下載路徑train=False, # 是訓(xùn)練集download=True, # 如果該路徑?jīng)]有該數(shù)據(jù)集,則進行下載transform=data_transform # 數(shù)據(jù)集轉(zhuǎn)換參數(shù)
)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)# 判斷是否有g(shù)pu
device = "cuda" if torch.cuda.is_available() else "cpu"# 調(diào)用net,將模型數(shù)據(jù)轉(zhuǎn)移到gpu
model = MyLeNet5().to(device)# 選擇損失函數(shù)
loss_fn = nn.CrossEntropyLoss() # 交叉熵?fù)p失函數(shù),自帶Softmax激活函數(shù)# 定義優(yōu)化器
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)# 學(xué)習(xí)率每隔10輪次, 變?yōu)樵瓉淼?.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# 定于訓(xùn)練函數(shù)
def train(dataloader, model, loss_fn, optimizer):loss, current, n = 0.0, 0.0, 0for batch, (X, y) in enumerate(dataloader):# 前向傳播X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)_, pred = torch.max(output, dim=1)# 計算當(dāng)前輪次時,訓(xùn)練集的精確度cur_acc = torch.sum(y == pred)/output.shape[0]# 反向傳播optimizer.zero_grad()cur_loss.backward()optimizer.step()loss += cur_loss.item()current += cur_acc.item()n = n + 1print("train_loss: ", str(loss/n))print("train_acc: ", str(current/n))def test(dataloader, model, loss_fn):model.eval()loss, current, n = 0.0, 0.0, 0# 該局部關(guān)閉梯度計算功能,提高運算效率with torch.no_grad():for batch, (X, y) in enumerate(dataloader):# 前向傳播X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)_, pred = torch.max(output, dim=1)# 計算當(dāng)前輪次時,訓(xùn)練集的精確度cur_acc = torch.sum(y == pred) / output.shape[0]loss += cur_loss.item()current += cur_acc.item()n = n + 1print("test_loss: ", str(loss / n))print("test_acc: ", str(current / n))return current/n # 返回精確度# 開始訓(xùn)練
epoch = 50
max_acc = 0
for t in range(epoch):print(f"epoch{t+1}\n---------------")train(train_dataloader, model, loss_fn, optimizer)a = test(test_dataloader, model, loss_fn)# 保存最好的模型參數(shù)if a > max_acc:folder = 'save_model'if not os.path.exists(folder):os.mkdir(folder)max_acc = aprint("current best model acc = ", a)torch.save(model.state_dict(), 'save_model/best_model.pth')
print("Done!")
5. 直觀顯示預(yù)測結(jié)果
? ? ?截至目前,我們已經(jīng)完成了手寫數(shù)字識別這個任務(wù),但是我們好像對于數(shù)據(jù)集長什么樣并不是很了解,似乎僅僅是用torchvision中的datasets庫下載了一下。因此本小節(jié),我們的目標(biāo)是從數(shù)據(jù)集取出幾個特定的手寫數(shù)字圖片,并查看我們模型對其的預(yù)測效果。
? ?首先我們還是加載數(shù)據(jù)集,和之前的代碼一樣,這里省略。然后我們加載模型:
from net import MyLeNet5# 調(diào)用net,將模型數(shù)據(jù)轉(zhuǎn)移到gpu
model = MyLeNet5().to(device)
model.load_state_dict(torch.load('./save_model/best_model.pth'))
? ? ?我們?nèi)〕鰷y試集中的前5長圖片做一個展示即可,完整代碼show.py如下:
import torch
from net import MyLeNet5
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImagedata_transform = transforms.Compose([transforms.ToTensor()
])# 加載訓(xùn)練數(shù)據(jù)集
train_dataset = datasets.MNIST(root='D:\\Jupyter\\dataset\\minst', # 下載路徑train=True, # 是訓(xùn)練集download=True, # 如果該路徑?jīng)]有該數(shù)據(jù)集,則進行下載transform=data_transform # 數(shù)據(jù)集轉(zhuǎn)換參數(shù)
)# 加載測試數(shù)據(jù)集
test_dataset = datasets.MNIST(root='D:\\Jupyter\\dataset\\minst', # 下載路徑train=False, # 是訓(xùn)練集download=True, # 如果該路徑?jīng)]有該數(shù)據(jù)集,則進行下載transform=data_transform # 數(shù)據(jù)集轉(zhuǎn)換參數(shù)
)# 判斷是否有g(shù)pu
device = "cuda" if torch.cuda.is_available() else "cpu"# 調(diào)用net,將模型數(shù)據(jù)轉(zhuǎn)移到gpu
model = MyLeNet5().to(device)
model.load_state_dict(torch.load('./save_model/best_model.pth'))# 獲取結(jié)果
classes = ["0","1","2","3","4","5","6","7","8","9"
]# 把tensor轉(zhuǎn)化為圖片,方便可視化
image = ToPILImage()# 進入驗證
for i in range(5):X, y = test_dataset[i][0], test_dataset[i][1] # X,y對應(yīng)第i張圖片和標(biāo)簽# image是ToPILImage的實例,將Pytorch張量轉(zhuǎn)換為PIL圖像,.show()方法會打開圖像查看器并顯示圖像image(X).show()'''unsqueeze 方法在指定的 dim 維度上擴展張量的維度。這里 dim=0,所以它會在第0維添加一個維度.例如,原來的 X 形狀是 (1, 28, 28),經(jīng)過 unsqueeze 處理后,形狀變?yōu)?(1, 1, 28, 28)。這樣做的目的是將單張圖像擴展成批次大小為1的形式,這樣模型可以接收單張圖像作為輸入。'''X = torch.unsqueeze(X, dim=0).float().to(device)with torch.no_grad():# 前向傳播獲得預(yù)測結(jié)果pred(由10個元素組成的張量)pred = model(X)print(pred)# 將預(yù)測值和標(biāo)簽轉(zhuǎn)化為對應(yīng)的數(shù)字分類結(jié)果,pred中的最大值視為預(yù)測分類predicted, actual = classes[torch.argmax(pred[0])], classes[y]print(f"predicted: {predicted}, actual: {actual}")
? ?運行結(jié)果如下,我們可以看到測試集中的前五張圖片分別是7,2,1,0,4,且我們的模型都能對其進行成功預(yù)測分類。
? ?預(yù)測結(jié)果均正確,如下:
寫在最后
? ? ? ? 本文介紹了如何使用PyTorch框架搭建卷積神經(jīng)網(wǎng)絡(luò)模型CNN。將CNN與DNN進行了類比。CNN中的卷積層與DNN的全連接層是平級關(guān)系。我們實現(xiàn)了LeNet-5的模型的搭建、模型訓(xùn)練、測試、網(wǎng)絡(luò)的復(fù)用、直觀查看數(shù)據(jù)集的圖片預(yù)測結(jié)果等,實現(xiàn)了機器學(xué)習(xí)領(lǐng)域的Hello world——手寫數(shù)字識別。在CNN原理中,讀者應(yīng)當(dāng)重點關(guān)注輸入輸出的尺寸關(guān)系,并可以對照LeNet-5結(jié)構(gòu)示意圖寫出對應(yīng)Pytorch代碼。至于模型訓(xùn)練和測試基本都是固定的代碼形式。
? ? ? 這篇文章到這里就結(jié)束了,后續(xù)我還會繼續(xù)更新深度學(xué)習(xí)的相關(guān)知識,另外近期我個人的研究方向涉及到圖神經(jīng)網(wǎng)絡(luò),回頭也會更新一些相關(guān)博客。如果讀者有相關(guān)建議或疑問也歡迎評論區(qū)與我共同探討,我一定知無不言??偨Y(jié)不易,還請讀者多多點贊關(guān)注支持!
?