人妖在线一区,国产日韩欧美一区二区综合在线,国产啪精品视频网站免费,欧美内射深插日本少妇

新聞動(dòng)態(tài)

pyTorch深入學(xué)習(xí)梯度和Linear Regression實(shí)現(xiàn)

發(fā)布日期:2021-12-29 09:32 | 文章來源:腳本之家

梯度

PyTorch的數(shù)據(jù)結(jié)構(gòu)是tensor,它有個(gè)屬性叫做requires_grad,設(shè)置為True以后,就開始track在其上的所有操作,前向計(jì)算完成后,可以通過backward來進(jìn)行梯度回傳。
評(píng)估模型的時(shí)候我們并不需要梯度回傳,使用with torch.no_grad() 將不需要梯度的代碼段包裹起來。每個(gè)Tensor都有一個(gè).grad_fn屬性,該屬性即創(chuàng)建該Tensor的Function,直接用構(gòu)造的tensor返回None,否則是生成該tensor的操作。

tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor
#require_grad默認(rèn)是false,下面我們將顯式的開啟
x = torch.tensor([1,2,3],requires_grad=True,dtype=torch.float)

注意只有數(shù)據(jù)類型是浮點(diǎn)型和complex類型才能require梯度,所以這里顯示指定dtype為torch.float32

x = torch.tensor([1,2,3],requires_grad=True,dtype=torch.float32)
> tensor([1.,2.,3.],grad_fn=None)
y = x + 2
> tensor([3.,4.,5.],grad_fn=<AddBackward0>)
z = y * y * 3
> tensor([3.,4.,5.],grad_fn=<MulBackward0>)

像x這種直接創(chuàng)建的,沒有g(shù)rad_fn,被稱為葉子結(jié)點(diǎn)。grad_fn記錄了一個(gè)個(gè)基本操作用來進(jìn)行梯度計(jì)算的。
關(guān)于梯度回傳計(jì)算看下面一個(gè)例子

x = torch.ones((2,2),requires_grad=True)
> tensor([[1.,1.],
> 		[1.,1.]],requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
#out是一個(gè)標(biāo)量,無需指定求偏導(dǎo)的變量
out.backward()
x.grad
> tensor([[4.500,4.500],
> 		  [4.500,4.500]])
#每次計(jì)算梯度前,需要將梯度清零,否則會(huì)累加
x.grad.data.zero_()

值得注意的是只有葉子節(jié)點(diǎn)的梯度在回傳時(shí)才會(huì)被計(jì)算,也就是說,上面的例子中拿不到y(tǒng)和z的grad。
來看一個(gè)中斷求導(dǎo)的例子

x = torch.tensor(1.,requires_grad=True)
y1 = x ** 2
with torch.no_grad()
	y2 = x ** 3
y3 = y1 + y2
y3.backward()
print(x.grad)
> 2

本來梯度應(yīng)該為5的,但是由于y2被with torch.no_grad()包裹,在梯度計(jì)算的時(shí)候不會(huì)被追蹤。

如果我們想要修改某個(gè)tensor的數(shù)值但是又不想被autograd記錄,那么需要使用對(duì)x.data進(jìn)行操作就行這也是一個(gè)張量。

線性回歸(linear regression)

利用線性回歸來預(yù)測一棟房屋的價(jià)格,價(jià)格取決于很多feature,這里簡化問題,假設(shè)價(jià)格只取決于兩個(gè)因素,面積(平方米)和房齡(年)

x1代表面積,x2代表房齡,售出價(jià)格為y

模擬數(shù)據(jù)集

假設(shè)我們的樣本數(shù)量為1000個(gè),每個(gè)數(shù)據(jù)包括兩個(gè)features,則數(shù)據(jù)為1000 * 2的2-d張量,用正太分布來隨機(jī)取值。
labels是房屋的價(jià)格,長度為1000的一維張量。
真實(shí)w和b提前把值定好,然后再取一個(gè)干擾量 δ \delta δ(也用高斯分布取值,用來模擬真實(shí)數(shù)據(jù)集中的偏差)

num_features = 2#兩個(gè)特征
num_examples = 1000 #樣本個(gè)數(shù)
w = torch.normal(0,1,(num_features,1))
b = torch.tensor(4.2)
samples = torch.normal(0,1,(num_examples,num_features))
labels = samples.matmul(w) + b
noise = torch.normal(0,.01,labels.shape)
labels += noise

加載數(shù)據(jù)集

import random
def data_iter(samples,labels,batch_size):
	num_samples = samples.shape[0] #獲得batch軸的長度
	indices = [i for i in range(num_samples)]
	random.shuffle(indices)#將索引數(shù)組原地打亂
	for i in range(0,num_samples,batch_size):
	j = torch.LongTensor(indices[i:min(i+batch_size,num_samples)])
	yield samples.index_select(0,j),labels(0,j)

torch.index_select(dim,index)
dim表示tensor的軸,index是一個(gè)tensor,里面包含的是索引。

定義loss_function

def loss_function(predict,labels):
	loss = (predict,labels)** 2 / 2
	return loss.mean()

定義優(yōu)化器

def loss_function(predict,labels):
	loss = (predict,labels)** 2 / 2
	return loss.mean()

開始訓(xùn)練

w = torch.normal(0.,1.,(num_features,1),requires_grad=True)
b = torch.zero(0.,dtype=torch.float32,requires_grad=True)
batch_size = 100
for epoch in range(10):
	for data, label in data_iter(samples,labels,batch_size):
		predict = data.matmul(w) + b
		loss = loss_function(predict,label)
		loss.backward()
		optimizer([w,b],0.05)
		w.grad.data.zero_()
		b.grad.data.zero_() 

以上就是pyTorch深入學(xué)習(xí)梯度和Linear Regression實(shí)現(xiàn)的詳細(xì)內(nèi)容,更多關(guān)于pyTorch實(shí)現(xiàn)梯度和Linear Regression的資料請(qǐng)關(guān)注本站其它相關(guān)文章!

版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請(qǐng)保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非www.sddonglingsh.com所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場,如有內(nèi)容涉嫌侵權(quán),請(qǐng)聯(lián)系alex-e#qq.com處理。

相關(guān)文章

實(shí)時(shí)開通

自選配置、實(shí)時(shí)開通

免備案

全球線路精選!

全天候客戶服務(wù)

7x24全年不間斷在線

專屬顧問服務(wù)

1對(duì)1客戶咨詢顧問

在線
客服

在線客服:7*24小時(shí)在線

客服
熱線

400-630-3752
7*24小時(shí)客服服務(wù)熱線

關(guān)注
微信

關(guān)注官方微信
頂部