1. Introduction to the CIFAR-100 Dataset
CIFAR-100 (Canadian Institute for Advanced Research, 100 classes) is a classic image-classification dataset used for research and algorithm benchmarking in computer vision. It extends the CIFAR-10 dataset with more classes, which makes the classification task considerably harder.
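CIFAR-100 contains 100 classes with 600 color images each (500 for training and 100 for testing), i.e. 50,000 training and 10,000 test images. Every image is 32x32 pixels with three RGB channels.
The 100 classes are grouped into 20 superclasses of 5 classes each, so every image carries a "fine" label (its class) and a "coarse" label (its superclass). This hierarchy, covering a wide variety of objects and scenes, makes the dataset richer.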
Use: CIFAR-100 is commonly used to evaluate image-classification algorithms. Its low resolution makes it a poor stand-in for many complex real-world vision tasks, but it remains a standard benchmark for academic research and algorithm development.
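You can check the fine/coarse hierarchy directly from the data. As far as I know, the python-version pickle files store two parallel label lists, b'fine_labels' and b'coarse_labels'. The sketch below builds the fine-to-coarse mapping and counts the classes per superclass. Both helper names are hypothetical and not part of any library.

```python
from collections import defaultdict


def build_fine_to_coarse(fine_labels, coarse_labels):
    """Map each fine label (0-99) to its coarse superclass (0-19).

    Assumes two parallel sequences, such as the b'fine_labels' and
    b'coarse_labels' lists of the CIFAR-100 python pickle.
    """
    mapping = {}
    for fine, coarse in zip(fine_labels, coarse_labels):
        # A fine class must always belong to the same superclass
        if mapping.setdefault(fine, coarse) != coarse:
            raise ValueError(f"fine label {fine} maps to several superclasses")
    return mapping


def superclass_sizes(mapping):
    """Count how many fine classes fall under each superclass."""
    sizes = defaultdict(int)
    for coarse in mapping.values():
        sizes[coarse] += 1
    return dict(sizes)
```

Load the real training file with pickle.load(f, encoding='bytes'), as in section 4. The mapping should then have 100 entries, and each of the 20 superclasses should count 5 classes.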
2. Downloading and Loading CIFAR-100
```python
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms


def get_train_loader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    # Training-time augmentation followed by normalization
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),   # pad to 40x40, then crop a random 32x32 patch
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),          # random rotation within ±15 degrees
        transforms.ToTensor(),                  # HWC uint8 [0, 255] -> CHW float [0, 1]
        transforms.Normalize(mean, std),
    ])
    cifar100_training = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform_train)
    cifar100_training_loader = DataLoader(
        cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
    return cifar100_training_loader


def get_val_loader(mean, std, batch_size=16, num_workers=2, shuffle=False):
    # Evaluation: no augmentation, only tensor conversion and normalization;
    # shuffling is off by default because evaluation order does not matter
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    cifar100_test = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test)
    cifar100_test_loader = DataLoader(
        cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
    return cifar100_test_loader
```
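Here we use torchvision to download CIFAR-100 and save it to the given path (./data). The two functions get_train_loader and get_val_loader return the data loaders for the training and validation sets. Both normalize the images, and the training loader also applies data augmentation (random crop, horizontal flip, rotation).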
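The mean and std passed to these loaders are per-channel statistics of the training images, taken after ToTensor() has scaled pixels to [0, 1]. Below is a minimal sketch of how such values can be computed. It assumes the images are an (N, H, W, 3) uint8 array. The function name is hypothetical, and the post does not say how its own constants were obtained.

```python
import numpy as np


def channel_mean_std(images_uint8):
    """Per-channel mean and std of an (N, H, W, C) uint8 image array.

    Pixels are scaled to [0, 1] first, as transforms.ToTensor() does
    before transforms.Normalize(mean, std) is applied.
    """
    means, stds = [], []
    for c in range(images_uint8.shape[-1]):
        # Convert one channel at a time to keep peak memory lower
        channel = images_uint8[..., c].astype(np.float64) / 255.0
        means.append(float(channel.mean()))
        stds.append(float(channel.std()))  # population std (ddof=0)
    return tuple(means), tuple(stds)
```

torchvision's CIFAR100 dataset exposes its images as exactly such an array through the .data attribute. So channel_mean_std(torchvision.datasets.CIFAR100(root='./data', train=True, download=True).data) should land close to the constants used below. Small differences are possible depending on the std convention.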
3. Checking the Loaded Data
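I have been burned by this many times. When we hit a dimension mismatch, we usually assume the network is at fault, but the data-loading code can just as well be the culprit. Open-source datasets like this one are usually clean. Our projects, however, use self-built datasets whose images can genuinely be faulty. For example, you load every image with PIL and expect, reasonably, that all of them have three channels, yet some turn out to have only one.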
So, to build rigorous engineering habits and lay a foundation for your own projects, test the data pipeline even when you use an open-source dataset.
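For self-built datasets, you can handle the one-channel surprise described above when the image is loaded. With PIL, Image.open(path).convert('RGB') converts grayscale images to three channels and drops the alpha channel of RGBA images. If images arrive as NumPy arrays instead, a hypothetical helper along these lines does the same job:

```python
import numpy as np


def ensure_rgb(array):
    """Return an (H, W, 3) array whatever the channel layout of the input.

    Grayscale (H, W) and (H, W, 1) images get their single channel repeated;
    RGBA (H, W, 4) images have the alpha channel dropped.
    """
    if array.ndim == 2:
        array = array[:, :, np.newaxis]
    if array.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {array.shape}")
    channels = array.shape[2]
    if channels == 3:
        return array
    if channels == 1:
        return np.repeat(array, 3, axis=2)
    if channels == 4:
        return array[:, :, :3]
    raise ValueError(f"unsupported channel count: {channels}")
```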
In general, the two main things to check are whether the dimensions are correct and whether the images display properly.
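You can make the dimension check explicit instead of reading it off a print. The sketch below raises as soon as a batch looks wrong. The helper name and its defaults are assumptions chosen to match CIFAR-100 (3x32x32 images, 100 classes). It only uses .shape, .min() and .max(), so it should accept torch tensors as well as the NumPy arrays used in the demo.

```python
import numpy as np


def check_batch(images, labels, image_shape=(3, 32, 32), num_classes=100):
    """Sanity-check one (images, labels) batch; returns the batch size."""
    batch_size = images.shape[0]
    if tuple(images.shape[1:]) != tuple(image_shape):
        raise ValueError(f"image shape {tuple(images.shape[1:])}, expected {image_shape}")
    if labels.shape[0] != batch_size:
        raise ValueError(f"{labels.shape[0]} labels for {batch_size} images")
    if int(labels.min()) < 0 or int(labels.max()) >= num_classes:
        raise ValueError("label outside [0, num_classes)")
    return batch_size


if __name__ == "__main__":
    # Fake batch with the layout a CIFAR-100 DataLoader produces: (N, C, H, W)
    fake_images = np.zeros((16, 3, 32, 32), dtype=np.float32)
    fake_labels = np.arange(16)
    print(check_batch(fake_images, fake_labels))  # 16
```

In the test loop, check_batch(images, labels) can sit next to the commented-out print, so the check fails fast instead of relying on visual inspection.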
Since the loaders above normalize the images, we first have to undo the normalization before displaying them:
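```python
import torch


def denormalize(tensor, mean, std):
    """Undo transforms.Normalize: map a normalized tensor back to the original range.

    Works for a single image (C, H, W) and for a batch (N, C, H, W): mean and std
    are reshaped to (C, 1, 1) and broadcast over the trailing dimensions.
    (Iterating with zip(tensor, mean, std) would walk over the images of a
    batch, not its channels.) Returns a new tensor; the input is not modified.
    """
    if not torch.is_tensor(tensor):
        raise TypeError("Input should be a torch tensor.")
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor * std + mean
```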
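transforms.Normalize(mean, std) computes x_norm = (x - mean) / std for each channel, so the inverse is x = x_norm * std + mean. Here is a NumPy sketch of the round trip. It uses rounded constants and hypothetical function names, and reshapes mean and std to (3, 1, 1) so the same code broadcasts over one image (C, H, W) or a whole batch (N, C, H, W).

```python
import numpy as np

# Rounded versions of the CIFAR-100 training mean/std used in this post
MEAN = np.array([0.5071, 0.4865, 0.4409]).reshape(3, 1, 1)
STD = np.array([0.2673, 0.2564, 0.2762]).reshape(3, 1, 1)


def normalize_np(x):
    """Per-channel (x - mean) / std, as transforms.Normalize does."""
    return (x - MEAN) / STD


def denormalize_np(x):
    """Inverse mapping x * std + mean; (3, 1, 1) broadcasts over (C, H, W) or (N, C, H, W)."""
    return x * STD + MEAN
```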
To check the display, we don't want to show images one at a time, which is too slow. Instead, we display a whole batch at once:
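```python
import matplotlib
matplotlib.use('TkAgg')  # interactive backend; environment-specific, drop it if plots already show
import matplotlib.pyplot as plt
from torchvision.utils import make_grid


def show_batch(images, labels):
    # mean and std are the module-level normalization constants set in the test code
    images = denormalize(images, mean, std)
    img_grid = make_grid(images, nrow=4, padding=10, normalize=True)  # CHW grid, 4 images per row
    plt.imshow(img_grid.permute(1, 2, 0))  # imshow expects HWC
    plt.title(f"Labels: {labels.tolist()}")
    plt.show()
```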
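To know what show_batch actually displays, it helps to work out the size of the grid. The sketch below re-derives the layout arithmetic of torchvision.utils.make_grid as I understand it from its source: images per row, padding around each cell, and channel expansion for grayscale. It is an approximation for intuition, not the library's code, and the function name is hypothetical.

```python
import math


def make_grid_shape(batch_size, channels, height, width, nrow=8, padding=2):
    """Approximate (C, H, W) of the image make_grid returns for a batch."""
    out_channels = 3 if channels == 1 else channels  # single-channel input is repeated to 3
    if batch_size == 1:
        return (out_channels, height, width)  # a single image comes back without padding
    xmaps = min(nrow, batch_size)             # images per row
    ymaps = math.ceil(batch_size / xmaps)     # number of rows
    # Each cell is the image plus `padding`; one more `padding` closes the outer border
    return (out_channels, ymaps * (height + padding) + padding,
            xmaps * (width + padding) + padding)
```

With nrow=4 and padding=10, a batch of 16 CIFAR images gives a 3x178x178 grid. permute(1, 2, 0) turns that into the HWC layout plt.imshow expects.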
Test code:
```python
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid

    # Per-channel mean/std of the CIFAR-100 training set (RGB order)
    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    mean = CIFAR100_TRAIN_MEAN
    std = CIFAR100_TRAIN_STD

    # Uses get_val_loader, denormalize and show_batch defined above
    test_loader = get_val_loader(mean, std, batch_size=16, num_workers=2, shuffle=False)

    for images, labels in test_loader:
        show_batch(images, labels)
        # print(images.size(), labels)
```
The last two lines of the loop test the batch display and the tensor dimensions. Test them one at a time: comment out one line while the other runs.
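Integer labels in the plot title are hard to verify by eye. torchvision's CIFAR100 dataset keeps the class names in its classes attribute, so a small hypothetical helper can turn a batch of labels into names:

```python
def labels_to_names(labels, class_names):
    """Turn a batch of integer labels into readable class names.

    `labels` can be a list, a NumPy array or a torch tensor: anything that
    iterates over integer-like values.
    """
    return [class_names[int(i)] for i in labels]
```

For example, with names = test_loader.dataset.classes, the title in show_batch could become ", ".join(labels_to_names(labels, names)).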
4. A Custom Dataset Class for CIFAR-100
Key points about the Dataset class:
- A custom dataset class should inherit from torch.utils.data.Dataset.
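- A Dataset turns data stored in an arbitrary format into tensors, after reading, preprocessing or augmenting it. "Arbitrary format" might mean images stored in folders named after their classes, or a txt file listing image paths. Each call to the Dataset returns one processed sample as a tensor together with its label. Grouping samples into batches is the job of the DataLoader.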
- A Dataset subclass mainly needs to implement three methods: __init__, __len__ and __getitem__.
  - __init__(self, ...): initializes the dataset. This is where you typically load the data and set up the transforms. It is called when the dataset object is created.
  - __len__(self): returns the size of the dataset, i.e. the number of samples. It is called by len(dataset).
  - __getitem__(self, index): returns the sample at the given index. It gives access to individual samples, which the DataLoader uses during training and evaluation.
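This NumPy-only sketch illustrates the division of labour between Dataset and DataLoader. It has two parts, and both names are hypothetical:

- a toy map-style dataset that returns one sample per index;
- a simplified sequential stand-in for DataLoader (no shuffling, no workers) that stacks samples into batches, roughly what default collation does.

Real code would subclass torch.utils.data.Dataset and use DataLoader itself.

```python
import numpy as np


class ToyDataset:
    """Minimal map-style dataset: __len__ plus __getitem__ returning ONE sample."""

    def __init__(self, num_samples=10):
        self.images = np.arange(num_samples * 3 * 2 * 2, dtype=np.float32).reshape(num_samples, 3, 2, 2)
        self.labels = np.arange(num_samples) % 3

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return self.images[index], int(self.labels[index])


def iterate_batches(dataset, batch_size):
    """Fetch samples by index in order and stack them into (images, labels) batches."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        samples = [dataset[i] for i in range(start, stop)]
        images = np.stack([img for img, _ in samples])    # (B, C, H, W)
        labels = np.array([lbl for _, lbl in samples])    # (B,)
        yield images, labels
```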
```python
import os
import pickle
import numpy as np
from torch.utils.data import Dataset, DataLoader


class CIFAR100Dataset(Dataset):
    """Reads the 'train' or 'test' pickle file of the CIFAR-100 python version."""

    def __init__(self, path, transform=None, train=False):
        sub_path = 'train' if train else 'test'
        with open(os.path.join(path, sub_path), 'rb') as cifar100:
            # The file was pickled under Python 2, so keys come back as bytes
            self.data = pickle.load(cifar100, encoding='bytes')
        self.transform = transform

    def __len__(self):
        return len(self.data[b'fine_labels'])

    def __getitem__(self, index):
        label = self.data[b'fine_labels'][index]
        # Each row holds 3072 uint8 values: 1024 red, then 1024 green, then 1024 blue,
        # each channel stored row-major as a 32x32 image
        r = self.data[b'data'][index, :1024].reshape(32, 32)
        g = self.data[b'data'][index, 1024:2048].reshape(32, 32)
        b = self.data[b'data'][index, 2048:].reshape(32, 32)
        image = np.dstack((r, g, b))  # (32, 32, 3) HWC uint8, as ToPILImage expects
        if self.transform:
            image = self.transform(image)
        return image, label
```
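Each row of b'data' stores 1024 red values, then 1024 green, then 1024 blue, each channel row-major. Splitting and stacking with np.dstack, as __getitem__ does, is therefore equivalent to one reshape to channel-first (3, 32, 32) followed by a transpose to (32, 32, 3). Here is a sketch comparing the two (the function names are hypothetical):

```python
import numpy as np


def row_to_hwc_dstack(row):
    """The approach used in CIFAR100Dataset.__getitem__: split, reshape, dstack."""
    r = row[:1024].reshape(32, 32)
    g = row[1024:2048].reshape(32, 32)
    b = row[2048:].reshape(32, 32)
    return np.dstack((r, g, b))


def row_to_hwc_reshape(row):
    """Equivalent: (3072,) -> channel-first (3, 32, 32) -> channel-last (32, 32, 3)."""
    return row.reshape(3, 32, 32).transpose(1, 2, 0)
```

The same idea works for the whole array at once: data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1) yields every image in HWC layout.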
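Test code (it reuses CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD, transforms and show_batch from the previous sections):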
```python
if __name__ == "__main__":
    mean = CIFAR100_TRAIN_MEAN
    std = CIFAR100_TRAIN_STD
    transform_train = transforms.Compose([
        transforms.ToPILImage(),                 # the dataset yields (32, 32, 3) uint8 NumPy arrays
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # train=True reads the 'train' file; the default train=False would load the test split
    train_dataset = CIFAR100Dataset(path='./data/cifar-100-python', train=True,
                                    transform=transform_train)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    for images, labels in train_loader:
        show_batch(images, labels)
        # print(images.size(), labels)
```
Appendix
Full source code for this section
```python
import os
import pickle

import matplotlib
matplotlib.use('TkAgg')  # interactive backend; environment-specific, drop it if plots already show
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid

# Per-channel mean/std of the CIFAR-100 training set (RGB order)
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

__all__ = ["get_train_loader", "get_val_loader", "CIFAR100Dataset"]


class CIFAR100Dataset(Dataset):
    """CIFAR-100 'train' or 'test' split read from the python-version pickle files."""

    def __init__(self, path, transform=None, train=False):
        sub_path = 'train' if train else 'test'
        with open(os.path.join(path, sub_path), 'rb') as cifar100:
            self.data = pickle.load(cifar100, encoding='bytes')  # keys are bytes
        self.transform = transform

    def __len__(self):
        return len(self.data[b'fine_labels'])

    def __getitem__(self, index):
        label = self.data[b'fine_labels'][index]
        # 3072 values per row: 1024 red, 1024 green, 1024 blue (each a 32x32 image)
        r = self.data[b'data'][index, :1024].reshape(32, 32)
        g = self.data[b'data'][index, 1024:2048].reshape(32, 32)
        b = self.data[b'data'][index, 2048:].reshape(32, 32)
        image = np.dstack((r, g, b))  # (32, 32, 3) uint8
        if self.transform:
            image = self.transform(image)
        return image, label


class CIFAR100Test(Dataset):
    """Test split only; behaves like CIFAR100Dataset(path, transform, train=False)."""

    def __init__(self, path, transform=None):
        with open(os.path.join(path, 'test'), 'rb') as cifar100:
            self.data = pickle.load(cifar100, encoding='bytes')
        self.transform = transform

    def __len__(self):
        return len(self.data[b'data'])

    def __getitem__(self, index):
        label = self.data[b'fine_labels'][index]
        r = self.data[b'data'][index, :1024].reshape(32, 32)
        g = self.data[b'data'][index, 1024:2048].reshape(32, 32)
        b = self.data[b'data'][index, 2048:].reshape(32, 32)
        image = np.dstack((r, g, b))
        if self.transform:
            image = self.transform(image)
        return image, label


def get_train_loader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    cifar100_training = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform_train)
    return DataLoader(cifar100_training, shuffle=shuffle, num_workers=num_workers,
                      batch_size=batch_size)


def get_val_loader(mean, std, batch_size=16, num_workers=2, shuffle=False):
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    cifar100_test = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test)
    return DataLoader(cifar100_test, shuffle=shuffle, num_workers=num_workers,
                      batch_size=batch_size)


def denormalize(tensor, mean, std):
    """Undo transforms.Normalize for a (C, H, W) image or an (N, C, H, W) batch."""
    if not torch.is_tensor(tensor):
        raise TypeError("Input should be a torch tensor.")
    # Broadcast per-channel statistics; zip(tensor, mean, std) would iterate over images
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor * std + mean


def show_batch(images, labels):
    images = denormalize(images, CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    img_grid = make_grid(images, nrow=4, padding=10, normalize=True)
    plt.imshow(img_grid.permute(1, 2, 0))  # CHW -> HWC for imshow
    plt.title(f"Labels: {labels.tolist()}")
    plt.show()


def main1():
    """Check the torchvision-based validation loader."""
    test_loader = get_val_loader(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD,
                                 batch_size=16, num_workers=2, shuffle=False)
    for images, labels in test_loader:
        show_batch(images, labels)
        # print(images.size(), labels)


if __name__ == "__main__":
    # Check the custom Dataset class on the training split
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
    ])
    train_dataset = CIFAR100Dataset(path='./data/cifar-100-python', train=True,
                                    transform=transform_train)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    for images, labels in train_loader:
        show_batch(images, labels)
        # print(images.size(), labels)
```