人妖在线一区,国产日韩欧美一区二区综合在线,国产啪精品视频网站免费,欧美内射深插日本少妇

新聞動(dòng)態(tài)

Python 實(shí)現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測(cè)

發(fā)布日期:2021-12-08 14:08 | 文章來源:gibhub

1.LeNet模型訓(xùn)練腳本

整體的訓(xùn)練代碼如下,下面我會(huì)為大家詳細(xì)講解這些代碼的意思

import torch
import torchvision
from torchvision.transforms import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
from pytorch.lenet.model import LeNet
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
transform = transforms.Compose(
 # 將數(shù)據(jù)集轉(zhuǎn)換成tensor形式
 [transforms.ToTensor(),
  # 進(jìn)行標(biāo)準(zhǔn)化,0.5是均值,也是方差,對(duì)應(yīng)三個(gè)維度都是0.5
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
# 下載完整的數(shù)據(jù)集時(shí),download=True,第一個(gè)為保存的路徑,下載完后download要改為False
# 為訓(xùn)練集時(shí),train=True,為測(cè)試集時(shí),train=False
train_set = torchvision.datasets.CIFAR10('./data', train=True,
 download=False, transform=transform)
# 加載訓(xùn)練集,設(shè)置批次大小,是否打亂,number_works是線程數(shù),window不設(shè)置為0會(huì)報(bào)錯(cuò),linux可以設(shè)置非零
train_loader = DataLoader(train_set, batch_size=36,  shuffle=True, num_workers=0)
test_set = torchvision.datasets.CIFAR10('./data', train=False,
download=False, transform=transform)
# 設(shè)置的批次大小一次性將所有測(cè)試集圖片傳進(jìn)去
test_loader = DataLoader(test_set, batch_size=10000, shuffle=False, num_workers=0)
# 迭代測(cè)試集的圖片數(shù)據(jù)和標(biāo)簽值
test_img, test_label = next(iter(test_loader))
# CIFAR10的十個(gè)類別名稱
classes = ('plane', 'car', 'bird', 'cat', 'deer',
  'dog', 'frog', 'horse', 'ship', 'truck')
# # ----------------------------顯示圖片-----------------------------------
# def imshow(img, label):
#  fig = plt.figure()
#  for i in range(len(img)):
#ax = fig.add_subplot(1, len(img), i+1)
#nping = img[i].numpy().transpose([1, 2, 0])
#npimg = (nping * 2 + 0.5)
#plt.imshow(npimg)
#title = '{}'.format(classes[label[i]])
#ax.set_title(title)
#plt.axis('off')
#  plt.show()
# 
# 
# batch_image = test_img[: 5]
# label_img = test_label[: 5]
# imshow(batch_image, label_img)
# # ----------------------------------------------------------------------
net = LeNet()
# 定義損失函數(shù),nn.CrossEntropyLoss()自帶softmax函數(shù),所以模型的最后一層不需要softmax進(jìn)行激活
loss_function = nn.CrossEntropyLoss()
# 定義優(yōu)化器,優(yōu)化網(wǎng)絡(luò)模型所有參數(shù)
optimizer = optim.Adam(net.parameters(), lr=0.001)
# 迭代五次
for epoch in range(5):
 # 初始損失設(shè)置為0
 running_loss = 0
 # 循環(huán)訓(xùn)練集,從1開始
 for step, data in enumerate(train_loader, start=1):
  inputs, labels = data
  # 優(yōu)化器的梯度清零,每次循環(huán)都需要清零,否則梯度會(huì)無限疊加,相當(dāng)于增加批次大小
  optimizer.zero_grad()
  # 將圖片數(shù)據(jù)輸入模型中
  outputs = net(inputs)
  # 傳入預(yù)測(cè)值和真實(shí)值,計(jì)算當(dāng)前損失值
  loss = loss_function(outputs, labels)
  # 損失反向傳播
  loss.backward()
  # 進(jìn)行梯度更新
  optimizer.step()
  # 計(jì)算該輪的總損失,因?yàn)閘oss是tensor類型,所以需要用item()取具體值
  running_loss += loss.item()
  # 每500次進(jìn)行日志的打印,對(duì)測(cè)試集進(jìn)行預(yù)測(cè)
  if step % 500 == 0:
# torch.no_grad()就是上下文管理,測(cè)試時(shí)不需要梯度更新,不跟蹤梯度
with torch.no_grad():
 # 傳入所有測(cè)試集圖片進(jìn)行預(yù)測(cè)
 outputs = net(test_img)
 # torch.max()中dim=1是因?yàn)榻Y(jié)果為(batch, 10)的形式,我們只需要取第二個(gè)維度的最大值
 # max這個(gè)函數(shù)返回[最大值, 最大值索引],我們只需要取索引就行了,所以用[1]
 predict_y = torch.max(outputs, dim=1)[1]
 # (predict_y == test_label)相同返回True,不相等返回False,sum()對(duì)正確率進(jìn)行疊加
 # 因?yàn)橛?jì)算的變量都是tensor,所以需要用item()拿到取值
 accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
 # running_loss/500是計(jì)算每一個(gè)step的loss,即每一步的損失
 print('[%d, %5d] train_loss: %.3ftest_accuracy: %.3f' %
 (epoch+1, step, running_loss/500, accuracy))
 running_loss = 0.0
print('Finished Training!')
save_path = 'lenet.pth'
# 保存模型,字典形式
torch.save(net.state_dict(), save_path)

(1).下載CIFAR10數(shù)據(jù)集

首先要訓(xùn)練一個(gè)網(wǎng)絡(luò)模型,我們需要足夠多的圖片做數(shù)據(jù)集,這里我們用的是torchvision.dataset為我們提供的CIFAR10數(shù)據(jù)集(更多的數(shù)據(jù)集可以去pytorch官網(wǎng)查看pytorch官網(wǎng)提供的數(shù)據(jù)集)

train_set = torchvision.datasets.CIFAR10('./data', train=True,
 download=False, transform=transform)
test_set = torchvision.datasets.CIFAR10('./data', train=False,
download=False, transform=transform)

這部分代碼是下載CIFAR10,第一個(gè)參數(shù)是下載數(shù)據(jù)集后存放的路徑,train=True和False對(duì)應(yīng)下載的訓(xùn)練集和測(cè)試集,transform是對(duì)應(yīng)的圖像增強(qiáng)方式

(2).圖像增強(qiáng)

transform = transforms.Compose(
 # 將數(shù)據(jù)集轉(zhuǎn)換成tensor形式
 [transforms.ToTensor(),
  # 進(jìn)行標(biāo)準(zhǔn)化,0.5是均值,也是方差,對(duì)應(yīng)三個(gè)維度都是0.5
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

這就是簡(jiǎn)單的圖像圖像增強(qiáng),transforms.ToTensor()將數(shù)據(jù)集的所有圖像轉(zhuǎn)換成tensor, transforms.Normalize()是標(biāo)準(zhǔn)化處理,包含兩個(gè)元組對(duì)應(yīng)均值和標(biāo)準(zhǔn)差,每個(gè)元組包含三個(gè)元素對(duì)應(yīng)圖片的三個(gè)維度[channels, height, width],為什么是這樣排序,別問,問就是pytorch要求的,順序不能變,之后會(huì)看到transforms.Normalize([0.485, 0.406, 0.456], [0.229, 0.224, 0.225])這兩組數(shù)據(jù),這是官方給出的均值和標(biāo)準(zhǔn)差,之后標(biāo)準(zhǔn)化的時(shí)候會(huì)經(jīng)常用到

(3).加載數(shù)據(jù)集

# 加載訓(xùn)練集,設(shè)置批次大小,是否打亂,number_works是線程數(shù),window不設(shè)置為0會(huì)報(bào)錯(cuò),linux可以設(shè)置非零
train_loader = DataLoader(dataset=train_set, batch_size=36,  shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test_set, batch_size=36, shuffle=False, num_workers=0)

這里只簡(jiǎn)單的設(shè)置的四個(gè)參數(shù)也是比較重要的,第一個(gè)就是需要加載的訓(xùn)練集和測(cè)試集,shuffle=True表示將數(shù)據(jù)集打亂,batch_size表示一次性向設(shè)備放入36張圖片,打包成一個(gè)batch,這時(shí)圖片的shape就會(huì)從[3, 32, 32]----》[36, 3, 32, 32],傳入網(wǎng)絡(luò)模型的shape也必須是[None, channels, height, width],None代表一個(gè)batch多少張圖片,否則就會(huì)報(bào)錯(cuò),number_works是代表線程數(shù),window系統(tǒng)必須設(shè)置為0,否則會(huì)報(bào)錯(cuò),linux系統(tǒng)可以設(shè)置非0數(shù)

(4).顯示部分圖像

def imshow(img, label):
 fig = plt.figure()
 for i in range(len(img)):
  ax = fig.add_subplot(1, len(img), i+1)
  nping = img[i].numpy().transpose([1, 2, 0])
  npimg = (nping * 2 + 0.5)
  plt.imshow(npimg)
  title = '{}'.format(classes[label[i]])
  ax.set_title(title)
  plt.axis('off')
 plt.show()

batch_image = test_img[: 5]
label_img = test_label[: 5]
imshow(batch_image, label_img)

這部分代碼是顯示測(cè)試集當(dāng)中前五張圖片,運(yùn)行后會(huì)顯示5張拼接的圖片

由于這個(gè)數(shù)據(jù)集的圖片都比較小都是32x32的尺寸,有些可能也看的不太清楚,圖中顯示的是真實(shí)標(biāo)簽,注:顯示圖片的代碼可能會(huì)這個(gè)報(bào)警(Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).),警告解決的方法:將圖片數(shù)組轉(zhuǎn)成uint8類型即可,即 plt.imshow(npimg.astype(‘uint8'),但是那樣顯示出來的圖片會(huì)變,所以暫時(shí)可以先不用管。

(5).初始化模型

數(shù)據(jù)圖片處理完了,下面就是我們的正式訓(xùn)練過程

net = LeNet()
# 定義損失函數(shù),nn.CrossEntropyLoss()自帶softmax函數(shù),所以模型的最后一層不需要softmax進(jìn)行激活
loss_function = nn.CrossEntropyLoss()
# 定義優(yōu)化器,優(yōu)化模型所有參數(shù)
optimizer = optim.Adam(net.parameters(), lr=0.001)

首先初始化LeNet網(wǎng)絡(luò),定義交叉熵?fù)p失函數(shù),以及Adam優(yōu)化器,關(guān)于注釋寫的,我們可以ctrl+鼠標(biāo)左鍵查看CrossEntropyLoss(),翻到CrossEntropyLoss類,可以看到注釋寫的這個(gè)標(biāo)準(zhǔn)包含LogSoftmax函數(shù),所以搭建LetNet模型的最后一層沒有使用softmax激活函數(shù)

(6).訓(xùn)練模型及保存模型參數(shù)

for epoch in range(5):
 # 初始損失設(shè)置為0
 running_loss = 0
 # 循環(huán)訓(xùn)練集,從1開始
 for step, data in enumerate(train_loader, start=1):
  inputs, labels = data
  # 優(yōu)化器的梯度清零,每次循環(huán)都需要清零,否則梯度會(huì)無限疊加,相當(dāng)于增加批次大小
  optimizer.zero_grad()
  # 將圖片數(shù)據(jù)輸入模型中得到輸出
  outputs = net(inputs)
  # 傳入預(yù)測(cè)值和真實(shí)值,計(jì)算當(dāng)前損失值
  loss = loss_function(outputs, labels)
  # 損失反向傳播
  loss.backward()
  # 進(jìn)行梯度更新(更新W,b)
  optimizer.step()
  # 計(jì)算該輪的總損失,因?yàn)閘oss是tensor類型,所以需要用item()取到值
  running_loss += loss.item()
  # 每500次進(jìn)行日志的打印,對(duì)測(cè)試集進(jìn)行測(cè)試
  if step % 500 == 0:
# torch.no_grad()就是上下文管理,測(cè)試時(shí)不需要梯度更新,不跟蹤梯度
with torch.no_grad():
 # 傳入所有測(cè)試集圖片進(jìn)行預(yù)測(cè)
 outputs = net(test_img)
 # torch.max()中dim=1是因?yàn)榻Y(jié)果為(batch, 10)的形式,我們只需要取第二個(gè)維度的最大值,第二個(gè)維度是包含十個(gè)類別每個(gè)類別的概率的向量
 # max這個(gè)函數(shù)返回[最大值, 最大值索引],我們只需要取索引就行了,所以用[1]
 predict_y = torch.max(outputs, dim=1)[1]
 # (predict_y == test_label)相同返回True,不相等返回False,sum()對(duì)正確結(jié)果進(jìn)行疊加,最后除測(cè)試集標(biāo)簽的總個(gè)數(shù)
 # 因?yàn)橛?jì)算的變量都是tensor,所以需要用item()拿到取值
 accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
 # running_loss/500是計(jì)算每一個(gè)step的loss,即每一步的損失
 print('[%d, %5d] train_loss: %.3ftest_accuracy: %.3f' %
 (epoch+1, step, running_loss/500, accuracy))
 running_loss = 0.0
 
print('Finished Training!')
save_path = 'lenet.pth'
# 保存模型,字典形式
torch.save(net.state_dict(), save_path)

這段代碼注釋寫的很清楚,大家仔細(xì)看就能看懂,流程不復(fù)雜,多看幾遍就能理解,最后再對(duì)訓(xùn)練好的模型進(jìn)行保存就好了(* ̄︶ ̄)

2.預(yù)測(cè)腳本

上面已經(jīng)訓(xùn)練好了模型,得到了lenet.pth參數(shù)文件,預(yù)測(cè)就很簡(jiǎn)單了,可以去網(wǎng)上隨便找一張數(shù)據(jù)集包含的類別圖片,將模型參數(shù)文件載入模型,通過對(duì)圖像進(jìn)行一點(diǎn)處理,喂入模型即可,下面奉上代碼:

import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from pytorch.lenet.model import LeNet
classes = ('plane', 'car', 'bird', 'cat', 'deer',
  'dog', 'frog', 'horse', 'ship', 'truck')
transforms = transforms.Compose(
 # 對(duì)數(shù)據(jù)圖片調(diào)整大小
 [transforms.Resize([32, 32]),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
net = LeNet()
# 加載預(yù)訓(xùn)練模型
net.load_state_dict(torch.load('lenet.pth'))
# 網(wǎng)上隨便找的貓的圖片
img_path = '../../Photo/cat2.jpg'
img = Image.open(img_path)
# 圖片的處理
img = transforms(img)
# 增加一個(gè)維度,(channels, height, width)------->(batch, channels, height, width),pytorch要求必須輸入這樣的shape
img = torch.unsqueeze(img, dim=0)
with torch.no_grad():
 output = net(img)
 # dim=1,只取[batch, 10]中10個(gè)類別的那個(gè)維度,取預(yù)測(cè)結(jié)果的最大值索引,并轉(zhuǎn)換為numpy類型
 prediction1 = torch.max(output, dim=1)[1].data.numpy()
 # 用softmax()預(yù)測(cè)出一個(gè)概率矩陣
 prediction2 = torch.softmax(output, dim=1)
 # 得到概率最大的值得索引
 prediction2 = np.argmax(prediction2)
# 兩種方式都可以得到最后的結(jié)果
print(classes[int(prediction1)])
print(classes[int(prediction2)])

反正我最后預(yù)測(cè)出來結(jié)果把貓識(shí)別成了狗,還有90.01%的概率,就離譜哈哈哈,但也說明了LeNet這個(gè)網(wǎng)絡(luò)模型確實(shí)很淺,特征提取的不夠深,才會(huì)出現(xiàn)這種。

到此這篇關(guān)于Python 實(shí)現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測(cè)的文章就介紹到這了,更多相關(guān)LeNet網(wǎng)絡(luò)模型訓(xùn)練及預(yù)測(cè)內(nèi)容請(qǐng)搜索本站以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持本站!

版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請(qǐng)保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非www.sddonglingsh.com所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場(chǎng),如有內(nèi)容涉嫌侵權(quán),請(qǐng)聯(lián)系alex-e#qq.com處理。

相關(guān)文章

實(shí)時(shí)開通

自選配置、實(shí)時(shí)開通

免備案

全球線路精選!

全天候客戶服務(wù)

7x24全年不間斷在線

專屬顧問服務(wù)

1對(duì)1客戶咨詢顧問

在線
客服

在線客服:7*24小時(shí)在線

客服
熱線

400-630-3752
7*24小時(shí)客服服務(wù)熱線

關(guān)注
微信

關(guān)注官方微信
頂部