人妖在线一区,国产日韩欧美一区二区综合在线,国产啪精品视频网站免费,欧美内射深插日本少妇

新聞動態(tài)

解決pytorch中的kl divergence計算問題

發(fā)布日期:2022-04-14 16:17 | 文章來源:gibhub

偶然從pytorch討論論壇中看到的一個問題,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中計算結果不同,平時沒有注意到,記錄下

一篇關于KL散度、JS散度以及交叉熵對比的文章

kl divergence 介紹

KL散度( Kullback–Leibler divergence),又稱相對熵,是描述兩個概率分布 P 和 Q 差異的一種方法。計算公式:

可以發(fā)現(xiàn),P 和 Q 中元素的個數不用相等,只需要兩個分布中的離散元素一致。

舉個簡單例子:

兩個離散分布分布分別為 P 和 Q

P 的分布為:{1,1,2,2,3}

Q 的分布為:{1,1,1,1,1,2,3,3,3,3}

我們發(fā)現(xiàn),雖然兩個分布中元素個數不相同,P 的元素個數為 5,Q 的元素個數為 10。但里面的元素都有 “1”,“2”,“3” 這三個元素。

當 x = 1時,在 P 分布中,“1” 這個元素的個數為 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 這個元素的個數為 5,故 Q(x = 1) = 5/10 = 0.5

同理,

當 x = 2 時,P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1

當 x = 3 時,P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4

把上述概率帶入公式:

至此,就計算完成了兩個離散變量分布的KL散度。

pytorch 中的 kl_div 函數

pytorch中有用于計算kl散度的函數 kl_div

torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')

計算 D (p||q)

1、不用這個函數的計算結果為:

與手算結果相同

2、使用函數:

(這是計算正確的,結果有差異是因為pytorch這個函數中默認的是以e為底)

注意:

1、函數中的 p q 位置相反(也就是想要計算D(p||q),要寫成kl_div(q.log(),p)的形式),而且q要先取 log

2、reduction 是選擇對各部分結果做什么操作,默認為取平均數,這里選擇求和

好別扭的用法,不知道為啥官方把它設計成這樣

補充:pytorch 的KL divergence的實現(xiàn)

看代碼吧~

import torch.nn.functional as F
# p_logit: [batch, class_num]
# q_logit: [batch, class_num]
def kl_categorical(p_logit, q_logit):
 p = F.softmax(p_logit, dim=-1)
 _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1)
- F.log_softmax(q_logit, dim=-1)), 1)
 return torch.mean(_kl)

以上為個人經驗,希望能給大家一個參考,也希望大家多多支持本站。

香港穩(wěn)定服務器

版權聲明:本站文章來源標注為YINGSOO的內容版權均為本站所有,歡迎引用、轉載,請保持原文完整并注明來源及原文鏈接。禁止復制或仿造本網站,禁止在非www.sddonglingsh.com所屬的服務器上建立鏡像,否則將依法追究法律責任。本站部分內容來源于網友推薦、互聯(lián)網收集整理而來,僅供學習參考,不代表本站立場,如有內容涉嫌侵權,請聯(lián)系alex-e#qq.com處理。

相關文章

實時開通

自選配置、實時開通

免備案

全球線路精選!

全天候客戶服務

7x24全年不間斷在線

專屬顧問服務

1對1客戶咨詢顧問

在線
客服

在線客服:7*24小時在線

客服
熱線

400-630-3752
7*24小時客服服務熱線

關注
微信

關注官方微信
頂部