Pytorch中的gather使用方法
官方說明
gather可以對一個Tensor進(jìn)行聚合,聲明為:torch.gather(input, dim, index, out=None) → Tensor
一般來說有三個參數(shù):輸入的變量input、指定在某一維上聚合的dim、聚合的使用的索引index,輸出為Tensor類型的結(jié)果(index必須為LongTensor類型)。
#參數(shù)介紹: input (Tensor) – The source tensor dim (int) – The axis along which to index index (LongTensor) – The indices of elements to gather out (Tensor, optional) – Destination tensor #當(dāng)輸入為三維時的計算過程: out[i][j][k] = input[index[i][j][k]][j][k] # dim=0 out[i][j][k] = input[i][index[i][j][k]][k] # dim=1 out[i][j][k] = input[i][j][index[i][j][k]] # dim=2 #樣例: t = torch.Tensor([[1,2],[3,4]]) torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])) # 1 1 # 4 3 #[torch.FloatTensor of size 2x2]
實驗
用下面的代碼在二維上做測試,以便更好地理解
t = torch.Tensor([[1,2,3],[4,5,6]]) index_a = torch.LongTensor([[0,0],[0,1]]) index_b = torch.LongTensor([[0,1,1],[1,0,0]]) print(t) print(torch.gather(t,dim=1,index=index_a)) print(torch.gather(t,dim=0,index=index_b))
輸出為:
>>tensor([[1., 2., 3.],
[4., 5., 6.]])
>>tensor([[1., 1.],
[4., 5.]])
>>tensor([[1., 5., 6.],
[4., 2., 3.]])
由于官網(wǎng)給的計算過程不太直觀,下面給出較為直觀的解釋:
對于index_a,dim為1表示在第二個維度上進(jìn)行聚合,索引為列號,[[0,0],[0,1]]表示結(jié)果的第一行取原數(shù)組第一行列號為[0,0]的數(shù),也就是[1,1],結(jié)果的第二行取原數(shù)組第二行列號為[0,1]的數(shù),也就是[4,5],這樣就得到了輸出的結(jié)果[[1,1],[4,5]]。
對于index_b,dim為0表示在第一個維度上進(jìn)行聚合,索引為行號,[[0,1,1],[1,0,0]]表示結(jié)果的第一行第d(d=0,1,2)列取原數(shù)組第d列行號為[0,1,1]的數(shù),也就是[1,5,6],類似的,結(jié)果的第二行第d列取原數(shù)組第d列行號為[1,0,0]的數(shù),也就是[4,2,3],這樣就得到了輸出的結(jié)果[[1,5,6],[4,2,3]]
接下來以index_a為例直接用官網(wǎng)的式子計算一遍加深理解:
output[0,0] = input[0,index[0,0]] #1 = input[0,0] output[0,1] = input[0,index[0,1]] #1 = input[0,0] output[1,0] = input[1,index[1,0]] #4 = input[1,0] output[1,1] = input[1,index[1,1]] #5 = input[1,1]
注
以下兩種寫法得到的結(jié)果是一樣的:
r1 = torch.gather(t,dim=1,index=index_a)
r2 = t.gather(1,index_a)
補充:Pytorch中的torch.gather函數(shù)的個人理解
最近在學(xué)習(xí)pytorch時遇到gather函數(shù),開始沒怎么理解,后來查閱網(wǎng)上相關(guān)資料后大概明白了原理。
gather()函數(shù)
在pytorch中,gather()函數(shù)的作用是將數(shù)據(jù)從input中按index提出,我們看gather函數(shù)的的官方文檔說明如下:
torch.gather(input, dim, index, out=None) → Tensor Gathers values along an axis specified by dim. For a 3-D tensor the output is specified by: out[i][j][k] = input[index[i][j][k]][j][k] # dim=0 out[i][j][k] = input[i][index[i][j][k]][k] # dim=1 out[i][j][k] = input[i][j][index[i][j][k]] # dim=2 Parameters: input (Tensor) – The source tensor dim (int) – The axis along which to index index (LongTensor) – The indices of elements to gather out (Tensor, optional) – Destination tensor Example: >>> t = torch.Tensor([[1,2],[3,4]]) >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])) 1 1 4 3 [torch.FloatTensor of size 2x2]
可以看出,在gather函數(shù)中我們用到的主要有三個參數(shù):
1)input:輸入
2)dim:維度,常用的為0和1
3)index:索引位置
貼一段代碼舉例說明:
a=t.arange(0,16).view(4,4) print(a) index_1=t.LongTensor([[3,2,1,0]]) b=a.gather(0,index_1) print(b) index_2=t.LongTensor([[0,1,2,3]]).t()#tensor轉(zhuǎn)置操作:(a)T=a.t() c=a.gather(1,index_2) print(c)
輸出如下:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[12, 9, 6, 3]])tensor([[ 0],
[ 5],
[10],
[15]])
在gather中,我們是通過index對input進(jìn)行索引把對應(yīng)的數(shù)據(jù)提取出來的,而dim決定了索引的方式。
在上面的例子中,a是一個4×4矩陣:
1)當(dāng)維度dim=0,索引index_1為[3,2,1,0]時,此時可將a看成1×4的矩陣,通過index_1對a每列進(jìn)行行索引:第一列第四行元素為12,第二列第三行元素為9,第三列第二行元素為6,第四列第一行元素為3,即b=[12,9,6,3];
2)當(dāng)維度dim=1,索引index_2為[0,1,2,3]T時,此時可將a看成4×1的矩陣,通過index_1對a每行進(jìn)行列索引:第一行第一列元素為0,第二行第二列元素為5,第三行第三列元素為10,第四行第四列元素為15,即c=[0,5,10,15]T;
總結(jié)
gather函數(shù)在提取數(shù)據(jù)時主要靠dim和index這兩個參數(shù),dim=1時將input看為n×1階矩陣,index看為k×1階矩陣,取index每行元素對input中每行進(jìn)行列索引(如:index某行為[1,3,0],對應(yīng)的input行元素為[9,8,7,6],提取后的結(jié)果為[8,6,9]);
同理,dim=0時將input看為1×n階矩陣,index看為1×k階矩陣,取index每列元素對input中每列進(jìn)行行索引。
gather函數(shù)提取后的矩陣階數(shù)和對應(yīng)的index階數(shù)相同。
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持本站。
版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非www.sddonglingsh.com所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場,如有內(nèi)容涉嫌侵權(quán),請聯(lián)系alex-e#qq.com處理。