人妖在线一区,国产日韩欧美一区二区综合在线,国产啪精品视频网站免费,欧美内射深插日本少妇

新聞動(dòng)態(tài)

PyTorch零基礎(chǔ)入門之邏輯斯蒂回歸

發(fā)布日期:2021-12-23 13:51 | 文章來源:gibhub

學(xué)習(xí)總結(jié)

(1)和上一講的模型訓(xùn)練是類似的,只是在線性模型的基礎(chǔ)上加個(gè)sigmoid,然后loss函數(shù)改為交叉熵BCE函數(shù)(當(dāng)然也可以用其他函數(shù)),另外一開始的數(shù)據(jù)y_data也從數(shù)值改為類別0和1(本例為二分類,注意x_datay_data這里也是矩陣的形式)。

一、sigmoid函數(shù)

logistic function是一種sigmoid函數(shù)(還有其他sigmoid函數(shù)),但由于使用過于廣泛,pytorch默認(rèn)logistic function叫為sigmoid函數(shù)。還有如下的各種sigmoid函數(shù):

二、和Linear的區(qū)別

邏輯斯蒂和線性模型的unit區(qū)別如下圖:

sigmoid函數(shù)是不需要參數(shù)的,所以不用對(duì)其初始化(直接調(diào)用nn.functional.sigmoid即可)。
另外loss函數(shù)從MSE改用交叉熵BCE:盡可能和真實(shí)分類貼近。

如下圖右方表格所示,當(dāng) y ^ \hat{y} y^​越接近y時(shí)則BCE Loss值越小。

三、邏輯斯蒂回歸(分類)PyTorch實(shí)現(xiàn)

# -*- coding: utf-8 -*-
"""
Created on Mon Oct 18 08:35:00 2021
@author: 86493
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt  
import torch.nn.functional as F
import numpy as np
# 準(zhǔn)備數(shù)據(jù)
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])

losslst = []
class LogisticRegressionModel(nn.Module):
 def __init__(self):
  super(LogisticRegressionModel, self).__init__()
  self.linear = torch.nn.Linear(1, 1)
  
 def forward(self, x):
 	# 和線性模型的網(wǎng)絡(luò)的唯一區(qū)別在這句,多了F.sigmoid
  y_predict = F.sigmoid(self.linear(x))
  return y_predict
 
model = LogisticRegressionModel()
# 使用交叉熵作損失函數(shù)
criterion = torch.nn.BCELoss(size_average = False)
optimizer = torch.optim.SGD(model.parameters(), 
lr = 0.01)
# 訓(xùn)練
for epoch in range(1000):
 y_predict = model(x_data)
 loss = criterion(y_predict, y_data)
 # 打印loss對(duì)象會(huì)自動(dòng)調(diào)用__str__
 print(epoch, loss.item())
 losslst.append(loss.item())
 # 梯度清零后反向傳播
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
# 畫圖
plt.plot(range(1000), losslst)
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.show()

# test
# 每周學(xué)習(xí)的時(shí)間,200個(gè)點(diǎn)
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
# 畫 probability of pass = 0.5的紅色橫線
plt.plot([0, 10], [0.5, 0.5], c = 'r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

可以看出處于通過和不通過的分界線是Hours=2.5。

Reference

pytorch官方文檔

到此這篇關(guān)于PyTorch零基礎(chǔ)入門之邏輯斯蒂回歸的文章就介紹到這了,更多相關(guān)PyTorch 邏輯斯蒂回歸內(nèi)容請(qǐng)搜索本站以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持本站!

版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請(qǐng)保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非www.sddonglingsh.com所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場(chǎng),如有內(nèi)容涉嫌侵權(quán),請(qǐng)聯(lián)系alex-e#qq.com處理。

相關(guān)文章

實(shí)時(shí)開通

自選配置、實(shí)時(shí)開通

免備案

全球線路精選!

全天候客戶服務(wù)

7x24全年不間斷在線

專屬顧問服務(wù)

1對(duì)1客戶咨詢顧問

在線
客服

在線客服:7*24小時(shí)在線

客服
熱線

400-630-3752
7*24小時(shí)客服服務(wù)熱線

關(guān)注
微信

關(guān)注官方微信
頂部