人妖在线一区,国产日韩欧美一区二区综合在线,国产啪精品视频网站免费,欧美内射深插日本少妇

新聞動態(tài)

pytorch實現(xiàn)手寫數(shù)字圖片識別

發(fā)布日期:2022-06-01 14:29 | 文章來源:腳本之家

本文實例為大家分享了pytorch實現(xiàn)手寫數(shù)字圖片識別的具體代碼,供大家參考,具體內(nèi)容如下

數(shù)據(jù)集:MNIST數(shù)據(jù)集,代碼中會自動下載,不用自己手動下載。數(shù)據(jù)集很小,不需要GPU設(shè)備,可以很好的體會到pytorch的魅力。
模型+訓(xùn)練+預(yù)測程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot
# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
 torchvision.datasets.MNIST('mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
 torchvision.transforms.ToTensor(),
 torchvision.transforms.Normalize(
  (0.1307,), (0.3081,)
 )
])),
 batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
 torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
transform=torchvision.transforms.Compose([
 torchvision.transforms.ToTensor(),
 torchvision.transforms.Normalize(
  (0.1307,), (0.3081,)
 )
])),
 batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")
class Net(nn.Module):
 def __init__(self):
  super(Net, self).__init__()
  self.fc1 = nn.Linear(28*28, 256)
  self.fc2 = nn.Linear(256, 64)
  self.fc3 = nn.Linear(64, 10)
 def forward(self, x):
  # x: [b, 1, 28, 28]
  # h1 = relu(xw1 + b1)
  x = F.relu(self.fc1(x))
  # h2 = relu(h1w2 + b2)
  x = F.relu(self.fc2(x))
  # h3 = h2w3 + b3
  x = self.fc3(x)
  return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []
for epoch in range(3):
 for batch_idx, (x, y) in enumerate(train_loader):
  #加載進來的圖片是一個四維的tensor,x: [b, 1, 28, 28], y:[512]
  #但是我們網(wǎng)絡(luò)的輸入要是一個一維向量(也就是二維tensor),所以要進行展平操作
  x = x.view(x.size(0), 28*28)
  #  [b, 10]
  out = net(x)
  y_onehot = one_hot(y)
  # loss = mse(out, y_onehot)
  loss = F.mse_loss(out, y_onehot)
  optimizer.zero_grad()
  loss.backward()
  # w' = w - lr*grad
  optimizer.step()
  train_loss.append(loss.item())
  if batch_idx % 10 == 0:
print(epoch, batch_idx, loss.item())
plot_curve(train_loss)
 # we get optimal [w1, b1, w2, b2, w3, b3]

total_correct = 0
for x,y in test_loader:
 x = x.view(x.size(0), 28*28)
 out = net(x)
 # out: [b, 10]
 pred = out.argmax(dim=1)
 correct = pred.eq(y).sum().float().item()
 total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中調(diào)用的函數(shù)(注意命名為utils):

import  torch
from matplotlib import pyplot as plt

def plot_curve(data):
 fig = plt.figure()
 plt.plot(range(len(data)), data, color='blue')
 plt.legend(['value'], loc='upper right')
 plt.xlabel('step')
 plt.ylabel('value')
 plt.show()

def plot_image(img, label, name):
 fig = plt.figure()
 for i in range(6):
  plt.subplot(2, 3, i + 1)
  plt.tight_layout()
  plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
  plt.title("{}: {}".format(name, label[i].item()))
  plt.xticks([])
  plt.yticks([])
 plt.show()

def one_hot(label, depth=10):
 out = torch.zeros(label.size(0), depth)
 idx = torch.LongTensor(label).view(-1, 1)
 out.scatter_(dim=1, index=idx, value=1)
 return out

打印出損失下降的曲線圖:

訓(xùn)練3個epoch之后,在測試集上的精度就可以89%左右,可見模型的準(zhǔn)確度還是很不錯的。
輸出六張測試集的圖片以及預(yù)測結(jié)果:

六張圖片的預(yù)測全部正確。

以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持本站。

海外穩(wěn)定服務(wù)器

版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非www.sddonglingsh.com所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場,如有內(nèi)容涉嫌侵權(quán),請聯(lián)系alex-e#qq.com處理。

相關(guān)文章

實時開通

自選配置、實時開通

免備案

全球線路精選!

全天候客戶服務(wù)

7x24全年不間斷在線

專屬顧問服務(wù)

1對1客戶咨詢顧問

在線
客服

在線客服:7*24小時在線

客服
熱線

400-630-3752
7*24小時客服服務(wù)熱線

關(guān)注
微信

關(guān)注官方微信
頂部