網(wǎng)站建設(shè)效果好不好seo對(duì)網(wǎng)店推廣的作用
全文鏈接:https://tecdat.cn/?p=33566
生成對(duì)抗網(wǎng)絡(luò)(GAN)是一種神經(jīng)網(wǎng)絡(luò),可以生成類似于人類產(chǎn)生的材料,如圖像、音樂、語音或文本(點(diǎn)擊文末“閱讀原文”獲取完整代碼數(shù)據(jù))。
相關(guān)視頻
最近我們被客戶要求撰寫關(guān)于GAN生成對(duì)抗性神經(jīng)網(wǎng)絡(luò)的研究報(bào)告,包括一些圖形和統(tǒng)計(jì)輸出。
近年來,GAN一直是研究的熱門話題。Facebook的AI研究總監(jiān)Yann LeCun稱對(duì)抗訓(xùn)練是“過去10年中最有趣的機(jī)器學(xué)習(xí)領(lǐng)域的想法”。
本文將介紹以下內(nèi)容:
什么是生成模型以及它與判別模型的區(qū)別
GAN的結(jié)構(gòu)和訓(xùn)練方式
如何使用PyTorch構(gòu)建GAN
如何使用GPU和PyTorch訓(xùn)練GAN以實(shí)現(xiàn)實(shí)際應(yīng)用
什么是生成對(duì)抗網(wǎng)絡(luò)?
生成對(duì)抗網(wǎng)絡(luò)是一種可以學(xué)習(xí)模仿給定數(shù)據(jù)分布的機(jī)器學(xué)習(xí)系統(tǒng)。它們最早是由深度學(xué)習(xí)專家Ian Goodfellow及其同事在2014年的一篇NeurIPS論文中提出的。
GAN由兩個(gè)神經(jīng)網(wǎng)絡(luò)組成,一個(gè)網(wǎng)絡(luò)用于生成數(shù)據(jù),另一個(gè)網(wǎng)絡(luò)用于區(qū)分真實(shí)數(shù)據(jù)和假數(shù)據(jù)(因此模型具有"對(duì)抗"的性質(zhì))。雖然生成數(shù)據(jù)的結(jié)構(gòu)并不新鮮,但在圖像和視頻生成方面,GAN取得了令人印象深刻的成果,例如:
使用CycleGAN進(jìn)行風(fēng)格轉(zhuǎn)換,可以對(duì)圖像進(jìn)行多種令人信服的風(fēng)格轉(zhuǎn)換
利用StyleGAN生成人臉,如網(wǎng)站This Person Does Not Exist上所示
判別模型與生成模型
如果您學(xué)習(xí)過神經(jīng)網(wǎng)絡(luò),那么您接觸到的大多數(shù)應(yīng)用很可能是使用判別模型實(shí)現(xiàn)的。而生成對(duì)抗網(wǎng)絡(luò)屬于一類不同的模型,被稱為生成模型。
在訓(xùn)練過程中,您會(huì)使用一個(gè)算法來調(diào)整模型的參數(shù)。目標(biāo)是通過最小化損失函數(shù)使模型學(xué)習(xí)到給定輸入的輸出的概率分布。在訓(xùn)練階段之后,您可以使用該模型通過估計(jì)輸入最可能對(duì)應(yīng)的數(shù)字來對(duì)新的手寫數(shù)字圖像進(jìn)行分類,如下圖所示:
您可以將用于分類問題的判別模型想象成使用訓(xùn)練數(shù)據(jù)來學(xué)習(xí)類別之間邊界的區(qū)塊。然后,它們使用這些邊界來區(qū)分輸入并預(yù)測(cè)其類別。數(shù)學(xué)上來說,判別模型學(xué)習(xí)輸出y給定輸入x的條件概率P(y|x)。
除了神經(jīng)網(wǎng)絡(luò),其他結(jié)構(gòu)也可以用作判別模型,例如邏輯回歸模型和支持向量機(jī)(SVM)。
然而,生成模型(如GAN)被訓(xùn)練為描述數(shù)據(jù)集的生成方式,以概率模型的形式進(jìn)行。通過從生成模型中采樣,您可以生成新的數(shù)據(jù)。雖然判別模型常用于有監(jiān)督學(xué)習(xí),但生成模型通常與無標(biāo)簽的數(shù)據(jù)集一起使用,并可被視為一種無監(jiān)督學(xué)習(xí)的形式。
使用手寫數(shù)字?jǐn)?shù)據(jù)集,您可以訓(xùn)練一個(gè)生成模型來生成新的數(shù)字。在訓(xùn)練階段,您會(huì)使用某種算法來調(diào)整模型的參數(shù),以最小化損失函數(shù)并學(xué)習(xí)訓(xùn)練集的概率分布。然后,通過訓(xùn)練好的模型,您可以生成新的樣本,如下圖所示:
為了輸出新的樣本,生成模型通??紤]到一個(gè)隨機(jī)元素,該隨機(jī)元素影響模型生成的樣本。用于驅(qū)動(dòng)生成器的隨機(jī)樣本是從"潛在空間"中獲得的,在該空間中,向量表示一種壓縮形式的生成樣本。
與判別模型不同,生成模型學(xué)習(xí)輸入數(shù)據(jù)x的概率P(x),通過具有輸入數(shù)據(jù)分布,它們能夠生成新的數(shù)據(jù)實(shí)例。
盡管GAN近年來受到了廣泛關(guān)注,但它們并不是唯一可用作生成模型的架構(gòu)。除了GAN,還有其他各種生成模型架構(gòu),例如:
伯勞茲曼機(jī)(Boltzmann machines)
變分自編碼器(Variational autoencoders)
隱馬爾可夫模型(Hidden Markov models)
預(yù)測(cè)序列中的下一個(gè)詞的模型,如GPT-2
然而,由于其在圖像和視頻生成方面取得的令人興奮的結(jié)果,GAN最近引起了公眾的最大關(guān)注。
現(xiàn)在您已了解生成模型的基礎(chǔ)知識(shí),接下來將介紹GAN的工作原理和訓(xùn)練方法。
生成對(duì)抗網(wǎng)絡(luò)(GAN)的架構(gòu)
生成對(duì)抗網(wǎng)絡(luò)由兩個(gè)神經(jīng)網(wǎng)絡(luò)組成,一個(gè)稱為"生成器"(generator),另一個(gè)稱為"判別器"(discriminator)。
生成器的作用是估計(jì)真實(shí)樣本的概率分布,以提供類似真實(shí)數(shù)據(jù)的生成樣本。而判別器則被訓(xùn)練來估計(jì)給定樣本來自真實(shí)數(shù)據(jù)的概率,而不是由生成器提供的。
這些結(jié)構(gòu)被稱為生成對(duì)抗網(wǎng)絡(luò),因?yàn)樯善骱团袆e器被訓(xùn)練以相互競(jìng)爭(zhēng):生成器試圖在愚弄判別器方面變得更好,而判別器試圖在識(shí)別生成樣本方面變得更好。
為了理解GAN的訓(xùn)練過程,考慮一個(gè)示例,包含一個(gè)由二維樣本(x?,?x?)組成的數(shù)據(jù)集,其中?x? 在 0 到 2π 的區(qū)間內(nèi),x? = sin(x?),如下圖所示:
可以看到,這個(gè)數(shù)據(jù)集由位于正弦曲線上的點(diǎn)(x?,?x?)組成,具有非常特殊的分布。GAN的整體結(jié)構(gòu)用于生成類似數(shù)據(jù)集樣本的(x??,?x??)對(duì),如下圖所示:
生成器G接收來自潛在空間的隨機(jī)數(shù)據(jù),并且其作用是生成類似真實(shí)樣本的數(shù)據(jù)。在這個(gè)示例中,我們有一個(gè)二維的潛在空間,因此生成器接收隨機(jī)的(z?,?z?)對(duì),并要求將它們轉(zhuǎn)化為類似真實(shí)樣本的形式。
生成對(duì)抗網(wǎng)絡(luò)(GAN)
作為生成對(duì)抗網(wǎng)絡(luò)的初次實(shí)驗(yàn),你將實(shí)現(xiàn)前面一節(jié)中描述的示例。
要運(yùn)行這個(gè)示例,你需要使用PyTorch庫,可以通過Anaconda Python發(fā)行版和conda軟件包和環(huán)境管理系統(tǒng)來安裝。
首先,創(chuàng)建一個(gè)conda環(huán)境并激活它:
$ conda create --name gan
$ conda activate gan
當(dāng)你激活conda環(huán)境后,你的命令提示符會(huì)顯示環(huán)境的名稱,即gan
。然后你可以在該環(huán)境中安裝必要的包:
$ conda install -c pytorch pytorch=1.4.0
$ conda install matplotlib jupyter
由于PyTorch是一個(gè)非常活躍的開發(fā)框架,其API可能會(huì)在新版本中發(fā)生變化。為了確保示例代碼能夠運(yùn)行,你需要安裝特定的版本1.4.0。
除了PyTorch,你還將使用Matplotlib進(jìn)行繪圖,并在Jupyter Notebook中運(yùn)行交互式代碼。這并不是強(qiáng)制性的,但它有助于進(jìn)行機(jī)器學(xué)習(xí)項(xiàng)目的工作。
在打開Jupyter Notebook之前,你需要注冊(cè)conda環(huán)境gan
,以便可以將其作為內(nèi)核來創(chuàng)建Notebook。要做到這一點(diǎn),在激活gan
環(huán)境后,運(yùn)行以下命令:
$ python -m ipykernel install --user --name gan
現(xiàn)在你可以通過運(yùn)行jupyter notebook
來打開Jupyter Notebook。通過點(diǎn)擊“新建”然后選擇“gan”來創(chuàng)建一個(gè)新的Notebook。
在Notebook中,首先導(dǎo)入必要的庫:
import torch
from torch import nnimport math
import matplotlib.pyplot as plt
在這里,你使用torch
導(dǎo)入了PyTorch庫。你還導(dǎo)入nn
,為了能夠以更簡(jiǎn)潔的方式設(shè)置神經(jīng)網(wǎng)絡(luò)。然后你導(dǎo)入math
來獲取pi常數(shù)的值,并按照慣例導(dǎo)入Matplotlib繪圖工具為plt
。
為了使實(shí)驗(yàn)在任何機(jī)器上都能完全復(fù)現(xiàn),最好設(shè)置一個(gè)隨機(jī)生成器種子。在PyTorch中,可以通過運(yùn)行以下代碼來實(shí)現(xiàn):
torch.manual_seed(111)
數(shù)字111
代表用于初始化隨機(jī)數(shù)生成器的隨機(jī)種子,它用于初始化神經(jīng)網(wǎng)絡(luò)的權(quán)重。盡管實(shí)驗(yàn)具有隨機(jī)性,但只要使用相同的種子,它應(yīng)該產(chǎn)生相同的結(jié)果。
現(xiàn)在環(huán)境已經(jīng)設(shè)置好,可以準(zhǔn)備訓(xùn)練數(shù)據(jù)了。
準(zhǔn)備訓(xùn)練數(shù)據(jù)
訓(xùn)練數(shù)據(jù)由一對(duì)(x?,x?)組成,其中x?是x?在區(qū)間從0到2π上的正弦值。你可以按照以下方式實(shí)現(xiàn)它:
train_data_length = 1024train_set = [(train_data[i], train_labels[i]) for i in range(train_data_length)
]
在這里,你創(chuàng)建了一個(gè)包含1024
對(duì)(x?,x?)的訓(xùn)練集。在第2行,你初始化了train_data
,它是一個(gè)具有1024
行和2
列的張量,所有元素都為零。張量是一個(gè)類似于NumPy數(shù)組的多維數(shù)組。
在第3行,你使用train_data
的第一列來存儲(chǔ)在0
到2π
區(qū)間內(nèi)的隨機(jī)值。然后,在第4行,你計(jì)算了張量的第二列,即第一列的正弦值。
接下來,你需要一個(gè)標(biāo)簽張量,PyTorch的數(shù)據(jù)加載器需要使用它。由于GAN使用無監(jiān)督學(xué)習(xí)技術(shù),標(biāo)簽可以是任何值。畢竟,它們不會(huì)被使用。
在第5行,你創(chuàng)建了一個(gè)填充了零的train_labels
張量。最后,在第6到8行,你將train_set
創(chuàng)建為元組列表,其中每個(gè)元組代表train_data
和train_labels
的每一行,正如PyTorch的數(shù)據(jù)加載器所期望的那樣。
你可以通過繪制每個(gè)點(diǎn)(x?,x?)來查看訓(xùn)練數(shù)據(jù):
plt.plot(train_data[:, 0], train_data[:, 1], ".")
輸出應(yīng)該類似于以下圖形:
使用train_set
,您可以創(chuàng)建一個(gè)PyTorch數(shù)據(jù)加載器:
batch_size = 32)
在這里,您創(chuàng)建了一個(gè)名為train_loader
的數(shù)據(jù)加載器,它將對(duì)train_set
中的數(shù)據(jù)進(jìn)行洗牌,并返回大小為32的樣本批次,您將使用這些批次來訓(xùn)練神經(jīng)網(wǎng)絡(luò)。
設(shè)置訓(xùn)練數(shù)據(jù)后,您需要為判別器和生成器創(chuàng)建神經(jīng)網(wǎng)絡(luò),它們將組成GAN。在下一節(jié)中,您將實(shí)現(xiàn)判別器。
實(shí)現(xiàn)判別器
在PyTorch中,神經(jīng)網(wǎng)絡(luò)模型由繼承自nn.Module
的類表示,因此您需要定義一個(gè)類來創(chuàng)建判別器。
判別別器是一個(gè)具有二維輸入和一維輸出的模型。它將接收來自真實(shí)數(shù)據(jù)或生成器的樣本,并提供樣本屬于真實(shí)訓(xùn)練數(shù)據(jù)的概率。下面的代碼展示了如何創(chuàng)建判別器:
class Discriminator(nn.Module):def __init__(self):nn.Linear(64, 1),nn.Sigmoid(),)def forward(self, x):output = self.model(x)return output
您使用. __init __()
來構(gòu)建模型。首先,您需要調(diào)用super().__init __()
來運(yùn)行nn.Module
中的.__init __()
。您使用的判別器是在nn.Sequential()
中以順序方式定義的MLP神經(jīng)網(wǎng)絡(luò)。它具有以下特點(diǎn):
第5和第6行:輸入為二維,第一個(gè)隱藏層由256個(gè)神經(jīng)元組成,并使用ReLU激活函數(shù)。
第8、9、11和12行:第二個(gè)和第三個(gè)隱藏層分別由128個(gè)和64個(gè)神經(jīng)元組成,并使用ReLU激活函數(shù)。
第14和第15行:輸出由一個(gè)神經(jīng)元組成,并使用sigmoidal激活函數(shù)表示概率。
第7、10和13行:在第一個(gè)、第二個(gè)和第三個(gè)隱藏層之后,您使用dropout來避免過擬合。
最后,您使用.forward()
來描述如何計(jì)算模型的輸出。這里,x
表示模型的輸入,它是一個(gè)二維張量。在此實(shí)現(xiàn)中,通過將輸入x
饋送到您定義的模型中而不進(jìn)行任何其他處理來獲得輸出。
聲明判別器類后,您應(yīng)該實(shí)例化一個(gè)Discriminator
對(duì)象:
discriminator = Discriminator()
discriminator
代表您定義的神經(jīng)網(wǎng)絡(luò)的一個(gè)實(shí)例,準(zhǔn)備好進(jìn)行訓(xùn)練。但是,在實(shí)現(xiàn)訓(xùn)練循環(huán)之前,您的GAN還需要一個(gè)生成器。您將在下一節(jié)中實(shí)現(xiàn)一個(gè)生成器。
實(shí)現(xiàn)生成器
在生成對(duì)抗網(wǎng)絡(luò)中,生成器是一個(gè)以潛在空間中的樣本作為輸入,并生成類似于訓(xùn)練集中數(shù)據(jù)的模型。在這種情況下,它是一個(gè)具有二維輸入的模型,將接收隨機(jī)點(diǎn)(z?,z?),并提供類似于訓(xùn)練數(shù)據(jù)中的(x??,x??)點(diǎn)的二維輸出。
實(shí)現(xiàn)類似于您為判別器所做的操作。首先,您必須創(chuàng)建一個(gè)從nn.Module
繼承并定義神經(jīng)網(wǎng)絡(luò)架構(gòu)的Generator
類,然后需要實(shí)例化一個(gè)Generator
對(duì)象:
class Generator(nn.Module):def __init__(self):super().__init__()generator = Generator()
在這里,generator
代表生成器神經(jīng)網(wǎng)絡(luò)。它由兩個(gè)具有16個(gè)和32個(gè)神經(jīng)元的隱藏層組成,兩者都使用ReLU激活函數(shù),以及一個(gè)具有2個(gè)神經(jīng)元的線性激活層作為輸出。這樣,輸出將由一個(gè)包含兩個(gè)元素的向量組成,可以是從負(fù)無窮大到正無窮大的任何值,這些值將表示(x??,x??)。
現(xiàn)在,您已定義了判別器和生成器的模型,可以開始進(jìn)行訓(xùn)練了!
訓(xùn)練模型
在訓(xùn)練模型之前,您需要設(shè)置一些參數(shù)來在訓(xùn)練過程中使用:
lr = 0.001
num_epochs = 300
在這里,您設(shè)置了以下參數(shù):
第1行設(shè)置學(xué)習(xí)率(
lr
),您將使用它來調(diào)整網(wǎng)絡(luò)權(quán)重。第2行設(shè)置了周期數(shù)(
num_epochs
),定義了對(duì)整個(gè)訓(xùn)練集進(jìn)行訓(xùn)練的重復(fù)次數(shù)。第3行將變量
loss_function
賦值為二進(jìn)制交叉熵函數(shù)BCELoss()
,這是您將用于訓(xùn)練模型的損失函數(shù)。
二進(jìn)制交叉熵函數(shù)是訓(xùn)練判別器的適用損失函數(shù),因?yàn)樗紤]了二元分類任務(wù)。它也適用于訓(xùn)練生成器,因?yàn)樗鼘⑵漭敵鲳佀徒o判別器,后者提供一個(gè)二進(jìn)制的可觀測(cè)輸出。
PyTorch在torch.optim
中實(shí)現(xiàn)了各種權(quán)重更新規(guī)則用于模型訓(xùn)練。您將使用Adam算法來訓(xùn)練判別器和生成器模型。要使用torch.optim
創(chuàng)建優(yōu)化器,請(qǐng)運(yùn)行以下代碼:
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
最后,你需要實(shí)現(xiàn)一個(gè)訓(xùn)練循環(huán),在該循環(huán)中,將訓(xùn)練樣本輸入模型,并更新其權(quán)重以最小化損失函數(shù):
for epoch in range(num_epochs):for n, (real_samples, _) in enumerate(train_loader):# 訓(xùn)練判別器的數(shù)據(jù)real_samples_labels = torch.ones((batch_size, 1))# 訓(xùn)練判別器discriminator.zero_grad()# 訓(xùn)練生成器的數(shù)據(jù)latent_space_samples = torch.randn((batch_size, 2))# 訓(xùn)練生成器generator.zero_grad()# 顯示損失if epoch % 10 == 0 and n == batch_size - 1:
對(duì)于生成對(duì)抗網(wǎng)絡(luò)(GANs),您需要在每個(gè)訓(xùn)練迭代中更新判別器和生成器的參數(shù)。與所有神經(jīng)網(wǎng)絡(luò)一樣,訓(xùn)練過程包括兩個(gè)循環(huán),一個(gè)用于訓(xùn)練周期,另一個(gè)用于每個(gè)周期的批處理。在內(nèi)部循環(huán)中,您開始準(zhǔn)備用于訓(xùn)練判別器的數(shù)據(jù):
第2行:?從數(shù)據(jù)加載器中獲取當(dāng)前批次的真實(shí)樣本,并將其賦值給
real_samples
。請(qǐng)注意,張量的第一個(gè)維度具有與batch_size
相等的元素?cái)?shù)量。這是在PyTorch中組織數(shù)據(jù)的標(biāo)準(zhǔn)方式,張量的每一行表示批次中的一個(gè)樣本。第4行:?使用
torch.ones()
為真實(shí)樣本創(chuàng)建標(biāo)簽,并將標(biāo)簽賦給real_samples_labels
。第5和第6行:?通過在
latent_space_samples
中存儲(chǔ)隨機(jī)數(shù)據(jù),創(chuàng)建生成的樣本,然后將其輸入生成器以獲得generated_samples
。第7行:?使用
torch.zeros()
將標(biāo)簽值0
分配給生成的樣本的標(biāo)簽,然后將標(biāo)簽存儲(chǔ)在generated_samples_labels
中。第8到11行:?將真實(shí)樣本和生成的樣本以及標(biāo)簽連接起來,并將其存儲(chǔ)在
all_samples
和all_samples_labels
中,您將使用它們來訓(xùn)練判別器。
接下來,在第14到19行,您訓(xùn)練了判別器:
第14行:?在PyTorch中,每個(gè)訓(xùn)練步驟都需要清除梯度,以避免積累。您可以使用
.zero_grad()
來實(shí)現(xiàn)這一點(diǎn)。第15行:?您使用訓(xùn)練數(shù)據(jù)
all_samples
計(jì)算判別器的輸出。第16和17行:?您使用模型的輸出
output_discriminator
和標(biāo)簽all_samples_labels
來計(jì)算損失函數(shù)。第18行:?您使用
loss_discriminator.backward()
計(jì)算梯度以更新權(quán)重。第19行:?您通過調(diào)用
optimizer_discriminator.step()
來更新判別器的權(quán)重。
接下來,在第22行,您準(zhǔn)備用于訓(xùn)練生成器的數(shù)據(jù)。您將隨機(jī)數(shù)據(jù)存儲(chǔ)在latent_space_samples
中,行數(shù)與batch_size
相等。由于您將二維數(shù)據(jù)作為輸入提供給生成器,因此使用了兩列。
然后,在第25到32行,您訓(xùn)練了生成器:
第25行:?使用
.zero_grad()
清除梯度。第26行:?將
latent_space_samples
提供給生成器,并將其輸出存儲(chǔ)在generated_samples
中。第27行:?將生成器的輸出輸入判別器,并將其輸出存儲(chǔ)在
output_discriminator_generated
中,您將使用其作為整個(gè)模型的輸出。第28到30行:?使用分類系統(tǒng)的輸出
output_discriminator_generated
和標(biāo)簽real_samples_labels
計(jì)算損失函數(shù),這些標(biāo)簽都等于1
。第31和32行:?計(jì)算梯度并更新生成器的權(quán)重。請(qǐng)記住,當(dāng)訓(xùn)練生成器時(shí),保持判別器權(quán)重凍結(jié),因?yàn)槟鷦?chuàng)建了
optimizer_generator
,其第一個(gè)參數(shù)等于generator.parameters()
。
最后,在第35到37行,您顯示了每十個(gè)周期結(jié)束時(shí)判別器和生成器損失函數(shù)的值。
由于此示例中使用的模型參數(shù)較少,訓(xùn)練將在幾分鐘內(nèi)完成。在接下來的部分中,您將使用訓(xùn)練的GAN生成一些樣本。
檢查GAN生成的樣本
生成對(duì)抗網(wǎng)絡(luò)被設(shè)計(jì)用于生成數(shù)據(jù)。因此,在訓(xùn)練過程結(jié)束后,您可以從潛在空間中獲取一些隨機(jī)樣本,并將它們提供給生成器以獲得一些生成的樣本:
latent_space_samples = torch.randn(100, 2)
然后,您可以繪制生成的樣本,并檢查它們是否類似于訓(xùn)練數(shù)據(jù)。在繪制generated_samples
數(shù)據(jù)之前,您需要使用.detach()
從PyTorch計(jì)算圖中返回一個(gè)張量,然后使用它來計(jì)算梯度:
generated_samples = generated_samples.detach()
輸出應(yīng)類似于以下圖像:
您可以看到,生成的數(shù)據(jù)分布類似于真實(shí)數(shù)據(jù)。通過使用固定的潛在空間樣本張量,并在訓(xùn)練過程的每個(gè)周期結(jié)束時(shí)將其提供給生成器,您可以可視化訓(xùn)練的演變。
點(diǎn)擊標(biāo)題查閱往期內(nèi)容
【視頻】Python用LSTM長(zhǎng)短期記憶神經(jīng)網(wǎng)絡(luò)對(duì)不穩(wěn)定降雨量時(shí)間序列進(jìn)行預(yù)測(cè)分析|數(shù)據(jù)分享
左右滑動(dòng)查看更多
01
02
03
04
手寫數(shù)字圖像與GAN
生成對(duì)抗網(wǎng)絡(luò)可以生成高維樣本,例如圖像。在此示例中,您將使用GAN生成手寫數(shù)字圖像。為此,您將使用包含手寫數(shù)字的MNIST數(shù)據(jù)集,該數(shù)據(jù)集已包含在torchvision包中。
首先,您需要在已激活的gan
?conda環(huán)境中安裝torchvision
:
$ conda install -c pytorch torchvision=0.5.0
與前面一樣,您使用特定版本的torchvision
來確保示例代碼可以運(yùn)行,就像您在pytorch
上所做的一樣。設(shè)置好環(huán)境后,您可以在Jupyter Notebook中開始實(shí)現(xiàn)模型。
與之前的示例一樣,首先導(dǎo)入必要的庫:
import torch
from torch import nn
除了之前導(dǎo)入的庫外,您還將需要torchvision
和transforms
來獲取訓(xùn)練數(shù)據(jù)并執(zhí)行圖像轉(zhuǎn)換。
同樣,設(shè)置隨機(jī)生成器種子以便能夠復(fù)制實(shí)驗(yàn):
torch.manual_seed(111)
由于此示例在訓(xùn)練集中使用圖像,所以模型需要更復(fù)雜,并且具有更多的參數(shù)。這使得訓(xùn)練過程變慢,當(dāng)在CPU上運(yùn)行時(shí),每個(gè)周期需要大約兩分鐘。您需要大約50個(gè)周期才能獲得相關(guān)結(jié)果,因此在使用CPU時(shí)的總訓(xùn)練時(shí)間約為100分鐘。
為了減少訓(xùn)練時(shí)間,如果您有可用的GPU,可以使用它來訓(xùn)練模型。但是,您需要手動(dòng)將張量和模型移動(dòng)到GPU上,以便在訓(xùn)練過程中使用它們。
您可以通過創(chuàng)建一個(gè)指向CPU或(如果有)GPU的device
對(duì)象來確保您的代碼將在任何一種設(shè)置上運(yùn)行:
device = ""device = torch.device("cpu")
稍后,您將使用此device
在可用的情況下使用GPU來設(shè)置張量和模型的創(chuàng)建位置。
現(xiàn)在基本環(huán)境已經(jīng)設(shè)置好了,您可以準(zhǔn)備訓(xùn)練數(shù)據(jù)。
準(zhǔn)備訓(xùn)練數(shù)據(jù)
MNIST數(shù)據(jù)集由28×28像素的灰度手寫數(shù)字圖像組成,范圍從0到9。為了在PyTorch中使用它們,您需要進(jìn)行一些轉(zhuǎn)換。為此,您定義了一個(gè)名為transform
的函數(shù)來加載數(shù)據(jù)時(shí)使用:
transform = transforms.Compose()
該函數(shù)分為兩個(gè)部分:
transforms.ToTensor()
將數(shù)據(jù)轉(zhuǎn)換為PyTorch張量。transforms.Normalize()
轉(zhuǎn)換張量系數(shù)的范圍。
由transforms.ToTensor()
產(chǎn)生的原始系數(shù)范圍從0到1,而且由于圖像背景是黑色,當(dāng)使用此范圍表示時(shí),大多數(shù)系數(shù)都等于0。
transforms.Normalize()
通過從原始系數(shù)中減去0.5
并將結(jié)果除以0.5
,將系數(shù)的范圍更改為-1到1。通過這種轉(zhuǎn)換,輸入樣本中為0的元素?cái)?shù)量大大減少,有助于訓(xùn)練模型。
transforms.Normalize()
的參數(shù)是兩個(gè)元組(M?, ..., M?)
和(S?, ..., S?)
,其中n
表示圖像的通道數(shù)量。MNIST數(shù)據(jù)集中的灰度圖像只有一個(gè)通道,因此元組只有一個(gè)值。因此,對(duì)于圖像的每個(gè)通道i
,transforms.Normalize()
從系數(shù)中減去M?
并將結(jié)果除以S?
。
現(xiàn)在,您可以使用torchvision.datasets.MNIST
加載訓(xùn)練數(shù)據(jù),并使用transform
進(jìn)行轉(zhuǎn)換:
train_set = torchvision.datasets.MNIST()
參數(shù)download=True
確保您第一次運(yùn)行上述代碼時(shí),MNIST數(shù)據(jù)集將會(huì)被下載并存儲(chǔ)在當(dāng)前目錄中,如參數(shù)root
所指示的位置。
現(xiàn)在您已經(jīng)創(chuàng)建了train_set
,可以像之前一樣創(chuàng)建數(shù)據(jù)加載器:
batch_size = 32
train_loader = torch.utils.data.DataLoader()
您可以使用Matplotlib繪制一些訓(xùn)練數(shù)據(jù)的樣本。為了改善可視化效果,您可以使用cmap=gray_r
來反轉(zhuǎn)顏色映射,并以黑色數(shù)字在白色背景上繪制:
real_samples, mnist_labels = next(iter(train_loader))
for i in range(16):
輸出應(yīng)該類似于以下內(nèi)容:
如您所見,有不同的手寫風(fēng)格的數(shù)字。隨著GAN學(xué)習(xí)數(shù)據(jù)的分布,它還會(huì)生成具有不同手寫風(fēng)格的數(shù)字。
現(xiàn)在您已經(jīng)準(zhǔn)備好了訓(xùn)練數(shù)據(jù),可以實(shí)現(xiàn)判別器和生成器模型。
實(shí)現(xiàn)判別器和生成器
本例中判別器是一個(gè)MLP神經(jīng)網(wǎng)絡(luò),它接收一個(gè)28 × 28像素的圖像,并提供圖像屬于真實(shí)訓(xùn)練數(shù)據(jù)的概率。
您可以使用以下代碼定義模型:
class Discriminator(nn.Module):def __init__(self):def forward(self, x):return output
為了將圖像系數(shù)輸入到MLP神經(jīng)網(wǎng)絡(luò)中,可以將它們進(jìn)行向量化,使得神經(jīng)網(wǎng)絡(luò)接收具有784
個(gè)系數(shù)的向量。
矢量化發(fā)生在.forward()
的第一行,因?yàn)檎{(diào)用x.view()
可以轉(zhuǎn)換輸入張量的形狀。在這種情況下,輸入x
的原始形狀是32×1×28×28,其中32是您設(shè)置的批量大小。轉(zhuǎn)換后,x
的形狀變?yōu)?2×784,每行表示訓(xùn)練集中圖像的系數(shù)。
要使用GPU運(yùn)行判別器模型,您必須實(shí)例化它并使用.to()
將其發(fā)送到GPU。要在有可用GPU時(shí)使用GPU,可以將模型發(fā)送到先前創(chuàng)建的device
對(duì)象:
discriminator = Discriminator().to(device=device)
由于生成器將生成更復(fù)雜的數(shù)據(jù),因此需要增加來自潛在空間的輸入維數(shù)。在這種情況下,生成器將接收一個(gè)100維的輸入,并提供一個(gè)由784個(gè)系數(shù)組成的輸出,這些系數(shù)將以28×28的張量表示為圖像。
下面是完整的生成器模型代碼:
class Generator(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(
在第12行,使用雙曲正切函數(shù)Tanh()
作為輸出層的激活函數(shù),因?yàn)檩敵鱿禂?shù)應(yīng)該在-1到1的區(qū)間內(nèi)。在第20行,實(shí)例化生成器并將其發(fā)送到device
以使用可用的GPU。
現(xiàn)在您已經(jīng)定義了模型,可以使用訓(xùn)練數(shù)據(jù)對(duì)它們進(jìn)行訓(xùn)練。
訓(xùn)練模型
要訓(xùn)練模型,需要定義訓(xùn)練參數(shù)和優(yōu)化器,就像在之前的示例中所做的那樣:
lr = 0.0001
num_epochs = 50
loss_function = nn.BCELoss()optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
為了獲得更好的結(jié)果,將學(xué)習(xí)率從先前的示例中降低。還將將epoch數(shù)設(shè)置為50
,以減少訓(xùn)練時(shí)間。
訓(xùn)練循環(huán)與之前的示例非常相似。在突出顯示的行中,將訓(xùn)練數(shù)據(jù)發(fā)送到device
以在有GPU可用時(shí)使用:
for epoch in range(num_epochs):for n, (real_samples, mnist_labels) in enumerate(train_loader):e, 100)).to(device=device)loss_generator = loss_function(output_discriminator_generated, real_samples_labels)loss_generator.backward()optimizer_generator.step().: {loss_generator}")
某些張量不需要使用device
顯式地發(fā)送到GPU。這適用于第11行中的generated_samples
,它將已經(jīng)被發(fā)送到可用GPU,因?yàn)?code>latent_space_samples和generator
先前已被發(fā)送到GPU。
由于此示例具有更復(fù)雜的模型,訓(xùn)練可能需要更長(zhǎng)時(shí)間。訓(xùn)練完成后,您可以通過生成一些手寫數(shù)字樣本來檢查結(jié)果。
檢查GAN生成的樣本
要生成手寫數(shù)字,您需要從潛在空間中隨機(jī)采樣一些樣本并將其提供給生成器:
latent_space_samples = torch.randn(batch_size, 100).to(device=device)
要繪制generated_samples
,您需要將數(shù)據(jù)移回GPU上運(yùn)行時(shí),以便在使用Matplotlib繪制數(shù)據(jù)之前,可以簡(jiǎn)單地調(diào)用.cpu()
。與之前一樣,還需要在使用Matplotlib繪制數(shù)據(jù)之前調(diào)用.detach()
:
generated_samples = generated_samples.cpu().detach()
for i in range(16):
輸出應(yīng)該是類似訓(xùn)練數(shù)據(jù)的數(shù)字,如以下圖片所示:
?經(jīng)過50個(gè)epoch的訓(xùn)練后,生成了一些類似真實(shí)數(shù)字的生成數(shù)字。通過增加訓(xùn)練epoch次數(shù),可以改善結(jié)果。
與之前的示例一樣,通過在訓(xùn)練過程的每個(gè)周期結(jié)束時(shí)使用固定的潛在空間樣本張量并將其提供給生成器,可以可視化訓(xùn)練的演變。
總結(jié)
您已經(jīng)學(xué)會(huì)了如何實(shí)現(xiàn)自己的生成對(duì)抗網(wǎng)絡(luò)。在深入探討生成手寫數(shù)字圖像的實(shí)際應(yīng)用之前,您首先通過一個(gè)簡(jiǎn)單的示例了解了GAN的結(jié)構(gòu)。
您看到,盡管GAN的復(fù)雜性很高,但像PyTorch這樣的機(jī)器學(xué)習(xí)框架通過提供自動(dòng)微分和簡(jiǎn)便的GPU設(shè)置,使其實(shí)現(xiàn)更加簡(jiǎn)單直觀。
在本文中,您學(xué)到了:
判別模型和生成模型的區(qū)別
如何結(jié)構(gòu)化和訓(xùn)練生成對(duì)抗網(wǎng)絡(luò)
如何使用PyTorch等工具和GPU來實(shí)現(xiàn)和訓(xùn)練GAN模型
GAN是一個(gè)非常活躍的研究課題,近年來提出了幾個(gè)令人興奮的應(yīng)用。如果您對(duì)此主題感興趣,請(qǐng)密切關(guān)注技術(shù)和科學(xué)文獻(xiàn),以獲取新的應(yīng)用想法。
點(diǎn)擊文末“閱讀原文”
獲取全文完整代碼數(shù)據(jù)資料。
本文選自《Python用GAN生成對(duì)抗性神經(jīng)網(wǎng)絡(luò)判別模型擬合多維數(shù)組、分類識(shí)別手寫數(shù)字圖像可視化》。
點(diǎn)擊標(biāo)題查閱往期內(nèi)容
PYTHON TENSORFLOW 2二維卷積神經(jīng)網(wǎng)絡(luò)CNN對(duì)圖像物體識(shí)別混淆矩陣評(píng)估|數(shù)據(jù)分享
R語言深度學(xué)習(xí)卷積神經(jīng)網(wǎng)絡(luò) (CNN)對(duì) CIFAR 圖像進(jìn)行分類:訓(xùn)練與結(jié)果評(píng)估可視化
R語言KERAS深度學(xué)習(xí)CNN卷積神經(jīng)網(wǎng)絡(luò)分類識(shí)別手寫數(shù)字圖像數(shù)據(jù)(MNIST)
MATLAB中用BP神經(jīng)網(wǎng)絡(luò)預(yù)測(cè)人體脂肪百分比數(shù)據(jù)
Python中用PyTorch機(jī)器學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)分類預(yù)測(cè)銀行客戶流失模型
R語言實(shí)現(xiàn)CNN(卷積神經(jīng)網(wǎng)絡(luò))模型進(jìn)行回歸數(shù)據(jù)分析
SAS使用鳶尾花(iris)數(shù)據(jù)集訓(xùn)練人工神經(jīng)網(wǎng)絡(luò)(ANN)模型
【視頻】R語言實(shí)現(xiàn)CNN(卷積神經(jīng)網(wǎng)絡(luò))模型進(jìn)行回歸數(shù)據(jù)分析
Python使用神經(jīng)網(wǎng)絡(luò)進(jìn)行簡(jiǎn)單文本分類
R語言用神經(jīng)網(wǎng)絡(luò)改進(jìn)Nelson-Siegel模型擬合收益率曲線分析
R語言基于遞歸神經(jīng)網(wǎng)絡(luò)RNN的溫度時(shí)間序列預(yù)測(cè)
R語言神經(jīng)網(wǎng)絡(luò)模型預(yù)測(cè)車輛數(shù)量時(shí)間序列
R語言中的BP神經(jīng)網(wǎng)絡(luò)模型分析學(xué)生成績(jī)
matlab使用長(zhǎng)短期記憶(LSTM)神經(jīng)網(wǎng)絡(luò)對(duì)序列數(shù)據(jù)進(jìn)行分類
R語言實(shí)現(xiàn)擬合神經(jīng)網(wǎng)絡(luò)預(yù)測(cè)和結(jié)果可視化
用R語言實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)預(yù)測(cè)股票實(shí)例
使用PYTHON中KERAS的LSTM遞歸神經(jīng)網(wǎng)絡(luò)進(jìn)行時(shí)間序列預(yù)測(cè)
python用于NLP的seq2seq模型實(shí)例:用Keras實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)機(jī)器翻譯
用于NLP的Python:使用Keras的多標(biāo)簽文本LSTM神經(jīng)網(wǎng)絡(luò)分類