分享

手写数字识别基本思路

 算法与编程之美 2023-05-03 发布于四川

问题

什么是MNIST?如何使用Pytorch实现手写数字识别?如何进行手写数字对模型进行检验?

方法

mnist数据集

MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集。

使用Pytorch实现手写数字识别

1.进行数据预处理对于MNIST数据集,可以通过torchvision中的datasets进行下载。

root (string):表示数据集的根目录,其中根目录存在MNIST/processed/training.pt和MNIST/processed/test.pt的子目录。

train (bool, optional):如果为True,则从training.pt创建数据集,否则从test.pt创建数据集。

download (bool, optional):如果为True,则从internet下载数据集并将其放入根目录。如果数据集已下载,则不会再次下载。

transform (callable, optional):接收PIL图片并返回转换后版本图片的转换函数。

bat = 128
transform = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize(0.1307, 0.3081)  # (均值,方差)
])  # Compoes 两个操作合为一个
train_ds = datasets.MNIST(root='data', download=False, train=True,
                         transform=transform)
train_ds, val_ds = torch.utils.data.random_split(train_ds, [50000, 10000])
test_ds = datasets.MNIST(root='data', download=True, train=False,
                        transform=transform)
train_loader = DataLoader(dataset=train_ds, batch_size=bat, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=bat)
test_loader = DataLoader(dataset=test_ds, batch_size=bat)

2.构建模型

class MyNet(nn.Module):
   def __init__(self) -> None:
       super().__init__()

       self.flatten = nn.Flatten()  # 将28*28的图像拉伸为784维向量
       # 第一个全连接层Full Connection(FC)
       self.fc1 = nn.Linear(in_features=784,
                            out_features=256)
       self.fc2 = nn.Linear(in_features=256,
                            out_features=128)
       self.fc3 = nn.Linear(in_features=128,
                            out_features=10)

   def forward(self, x):
       x = self.flatten(x)
       x = torch.relu(self.fc1(x))
       x = torch.relu(self.fc2(x))
       out = torch.relu(self.fc3(x))
       return out

构建一个三层的神经网络MNIST数据集中的图片都是28×28大小的,而且是灰度图。而全连接神经网络的输入要是一个行向量,所以我们要把28×28的矩阵转换成28×28=764的行向量,作为神经网络的输入

3.优化器的选择,参数设置

使用优化器和损失函数。优化器选择SGD,SGD随机梯度下降,lr学习率取值0.2最优,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生。

optimizer=torch.optim.SGD(net.parameters(),lr=0.2)#lr学习率,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生
#损失函数
#衡量yy_hat之间的差异
loss_fn=nn.CrossEntropyLoss()

4.对模型进行训练测试,网络的输入,输入尺寸B*C*H*W B是batch,一个batch一个batch交给网络处理,x=torch.rand(size=(128,1,28,28)),基于loss信息利用优化器从后向前更新网络全部参数。

def train(dataloader, net, loss_fn, optimizer, epoch):
   size = len(dataloader.dataset)
   corrent = 0
   epoch_loss = 0.0
   batch_num = len(dataloader)
   net.train()

   # 一个batch一个batch的训练网络
   for batch_idx, (X, y) in enumerate(dataloader):
       pred = net(X)

       # 衡量y与y_hat之间的loss
       # y:128, pred:128x10 CrossEntropyloss
       loss = loss_fn(pred, y)

       # 基于loss信息利用优化器从后向前更新网络全部参数 <---
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
       epoch_loss += loss.item()
       corrent += (pred.argmax(1) == y).type(torch.float).sum().item()
       if batch_idx % 100 == 0:
           # f-string
           print(f'[{batch_idx + 1:>5d}/{batch_num + 1:>5d}],loss:{loss.item()}')
   avg_loss = epoch_loss / batch_num
   avg_accuracy = corrent / size
   # loss_list.append(avg_loss)
   return avg_accuracy, avg_loss
def test(dataloader, net, loss_fn):
   size = len(dataloader.dataset)
   batch_num = len(dataloader)
   corrent = 0
   losses = 0
   net.eval()
   with torch.no_grad():
       for X, y in test_loader:
           pred = net(X)
           correct = (pred.argmax(1) == y).type(torch.int).sum().item()
           # print(y.size(0))
           # print(correct)
           corrent += correct
   accuracy = corrent / size
   avg_loss = losses / batch_num
   return accuracy, avg_loss


5.保存最优的模型

net.load_state_dict(torch.load('model_best.pth'))
   test(test_loader,net,loss_fn)

6.读入自己的写入数字,进行识别

model = MyNet()
model.load_state_dict(torch.load('model_best.pth'))
img = Image.open("7.png").convert("L")  # 转为灰度图像
img = transform(img)
# img = np.array(img)
# print(img)
result = model(img)
_, predict = torch.max(result.data, dim=1)
print(result)
print("the result is:",predict.item())

结语

minist是一个28*28的图像,所以输入就是28*28=784的维度,输出为10,0-9十个数字。手写数字识别首先需要初始化全局变量,构建数据集。然后构建模型,构建迭代器与损失函数,进行训练测试。最后可以将训练的模型进行保存,通过读取自己写的数字进行识别验证,完成一个简单的深度学习。

    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多