PyTorch零基礎(chǔ)入門之邏輯斯蒂回歸
學(xué)習(xí)總結(jié)
(1)和上一講的模型訓(xùn)練是類似的,只是在線性模型的基礎(chǔ)上加個(gè)sigmoid,然后loss函數(shù)改為交叉熵BCE函數(shù)(當(dāng)然也可以用其他函數(shù)),另外一開始的數(shù)據(jù)y_data也從數(shù)值改為類別0和1(本例為二分類,注意x_data
和y_data
這里也是矩陣的形式)。
一、sigmoid函數(shù)
logistic function是一種sigmoid函數(shù)(還有其他sigmoid函數(shù)),但由于使用過于廣泛,pytorch默認(rèn)logistic function叫為sigmoid函數(shù)。還有如下的各種sigmoid函數(shù):
二、和Linear的區(qū)別
邏輯斯蒂和線性模型的unit區(qū)別如下圖:
sigmoid
函數(shù)是不需要參數(shù)的,所以不用對(duì)其初始化(直接調(diào)用nn.functional.sigmoid
即可)。
另外loss函數(shù)從MSE改用交叉熵BCE:盡可能和真實(shí)分類貼近。
如下圖右方表格所示,當(dāng) y ^ \hat{y} y^越接近y時(shí)則BCE Loss值越小。
三、邏輯斯蒂回歸(分類)PyTorch實(shí)現(xiàn)
# -*- coding: utf-8 -*- """ Created on Mon Oct 18 08:35:00 2021 @author: 86493 """ import torch import torch.nn as nn import matplotlib.pyplot as plt import torch.nn.functional as F import numpy as np # 準(zhǔn)備數(shù)據(jù) x_data = torch.Tensor([[1.0], [2.0], [3.0]]) y_data = torch.Tensor([[0], [0], [1]]) losslst = [] class LogisticRegressionModel(nn.Module): def __init__(self): super(LogisticRegressionModel, self).__init__() self.linear = torch.nn.Linear(1, 1) def forward(self, x): # 和線性模型的網(wǎng)絡(luò)的唯一區(qū)別在這句,多了F.sigmoid y_predict = F.sigmoid(self.linear(x)) return y_predict model = LogisticRegressionModel() # 使用交叉熵作損失函數(shù) criterion = torch.nn.BCELoss(size_average = False) optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # 訓(xùn)練 for epoch in range(1000): y_predict = model(x_data) loss = criterion(y_predict, y_data) # 打印loss對(duì)象會(huì)自動(dòng)調(diào)用__str__ print(epoch, loss.item()) losslst.append(loss.item()) # 梯度清零后反向傳播 optimizer.zero_grad() loss.backward() optimizer.step() # 畫圖 plt.plot(range(1000), losslst) plt.ylabel('Loss') plt.xlabel('epoch') plt.show() # test # 每周學(xué)習(xí)的時(shí)間,200個(gè)點(diǎn) x = np.linspace(0, 10, 200) x_t = torch.Tensor(x).view((200, 1)) y_t = model(x_t) y = y_t.data.numpy() plt.plot(x, y) # 畫 probability of pass = 0.5的紅色橫線 plt.plot([0, 10], [0.5, 0.5], c = 'r') plt.xlabel('Hours') plt.ylabel('Probability of Pass') plt.grid() plt.show()
可以看出處于通過和不通過的分界線是Hours=2.5。
Reference
pytorch官方文檔
到此這篇關(guān)于PyTorch零基礎(chǔ)入門之邏輯斯蒂回歸的文章就介紹到這了,更多相關(guān)PyTorch 邏輯斯蒂回歸內(nèi)容請(qǐng)搜索本站以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持本站!
版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請(qǐng)保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非www.sddonglingsh.com所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場(chǎng),如有內(nèi)容涉嫌侵權(quán),請(qǐng)聯(lián)系alex-e#qq.com處理。