分享

怎么调用pytorch中mnist数据集

 算法与编程之美 2023-04-13 发布于四川

问题

怎么调用pytorch中mnist数据集

方法

MNIST数据集介绍

MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集主要包括四个文件,训练集train一共包含了 60000 张图像和标签,而测试集一共包含了 10000 张图像和标签。

idx3表示3维,ubyte表示是以字节的形式进行存储的,t10k表示10000张测试图片(test10000),每张图片是一个28*28像素点的0 ~ 9的灰质手写数字图片,黑底白字,图像像素值为0 ~ 255,越大该点越白。

数据下载和读取

导入PyTorch的两个核心库torch和torchvision,这两个库基本包含了PyTorch会用到的许多方法和函数,其他库为下面所需要的一些辅助库。

import gzip

import os

import torch

import torchvision

import numpy as np

from PIL import Image

from matplotlib import pyplot as plt

from torchvision import datasets, transforms

from torch.utils.data import DataLoader, Dataset

import datasets是为了方便自动下载数据集,可以下载多种数据集,如MNIST、ImageNet、CIFAR10等。

import transforms是pytorch中的图像预处理库,一般用Compose把多个步骤整合到一起。相关详情见:transforms.Compose()函数

使用Pytorch自带的库函数

导入MNIST数据集代码:

train_data = datasets.MNIST(

           root="./data/",

           train=True,

           transform=transforms.To

通过重构Dataset类读取特定的MNIST数据或者制作自己的MNIST数据集

① 读取MNIST文件夹下processed文件中的training.pt、test.pt数据集

class Data_Loader(Dataset):

   def __init__(self, root, transform=None):

       self.data, self.targets = torch.load(root)#采用torch.load进行读取,读取之后的结果为torch.Tensor形式

self.transform = transform

   def __getitem__(self, index):

       img, target = self.data[index], int(self.targets[index])

       img = Image.fromarray(img.numpy(), mode='L')

       if self.transform is not None:

           img = self.transform(img)

       img = transforms.ToTensor()(img)

       return img, target

   def __len__(self):

       return len(self.data)

接下来,调用我们自定义的Data_Loader类来读取数据集:

# root 为training.pt、test.pt文件所在的绝对路径

train_data = Data_Loader(root='./mnist/MNIST/processed/training.pt', transform= None)

test_data = Data_Loader(root='./mnist/MNIST/processed/test.pt', transform= None)

再使用torch.utils.data.DataLoader对train_data和test_data进行加载,展示。

② 读取MNIST文件夹下raw文件中的数据集

class Data_Loader(Dataset):

   def __init__(self, folder, data_name, label_name, transform=None):

       (train_set, train_labels) = load_data(folder, data_name, label_name)

       self.train_set = train_set

       self.train_labels = train_labels

       self.transform = transform

   def __getitem__(self, index):

       img, target = self.train_set[index], int(self.train_labels[index])

       if self.transform is not None:

           img = self.transform(img)

       return img, target

   def __len__(self):

       return len(self.train_set)

def load_data(data_folder, data_name, label_name):

   with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:  # rb表示的是读取二进制数据

       y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

   with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:

       x_train = np.frombuffer(

           imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

   return x_train, y_train

接下来,调用我们自定义的Data_Loader类来读取数据集:

#folder:MNIST数据集中raw文件的绝对路径

# 读取MNIST数据集中的训练集

train_data = Data_Loader('./MNIST/MNIST/raw', "train-images-idx3-ubyte.gz",

                          "train-labels-idx1-ubyte.gz", transform=transforms.ToTensor())

# 读取MNIST数据集中的测试集

test_data = Data_Loader('./MNIST/MNIST/raw', "t10k-images-idx3-ubyte.gz",

                          "t10k-labels-idx1-ubyte.gz", transform=transforms.ToTensor())

再使用torch.utils.data.DataLoader对train_data和test_data进行加载,展示。

③ 直接读取MNIST数据集

总结

mnist数据集是一个计算机视觉数据集,训练集包括六万张图片,测试集一万张图片,并且已经进行过预处理和格式化。这些数据集有两个功能:一个功能是提供了大量的数据作为训练集和验证集,为一些学习人员提供了丰富的样 本信息一一这一点很宝贵,要知道在深度学习领域要想在一个方面有比较深的研究成果, 除了需要具备一定的网络设计和调优能力以外,还有一个就是要有丰富的训练样本。另一 个功能就是可以形成一个在业内相对有普适性的 Benchmark 比对项目一一既然大家用的数 据集都是一样的,那么每个人设计出来的网络就可以在这些数据集上不断互相比较,从而 验证谁家的网络设计得识别率更高。

    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多