python深度學(xué)習(xí)TensorFlow神經(jīng)網(wǎng)絡(luò)模型的保存和讀取
之前的筆記里實(shí)現(xiàn)了softmax回歸分類、簡單的含有一個(gè)隱層的神經(jīng)網(wǎng)絡(luò)、卷積神經(jīng)網(wǎng)絡(luò)等等,但是這些代碼在訓(xùn)練完成之后就直接退出了,并沒有將訓(xùn)練得到的模型保存下來方便下次直接使用。為了讓訓(xùn)練結(jié)果可以復(fù)用,需要將訓(xùn)練好的神經(jīng)網(wǎng)絡(luò)模型持久化,這就是這篇筆記里要寫的東西。
TensorFlow提供了一個(gè)非常簡單的API,即tf.train.Saver
類來保存和還原一個(gè)神經(jīng)網(wǎng)絡(luò)模型。
下面代碼給出了保存TensorFlow模型的方法:
import tensorflow as tf # 聲明兩個(gè)變量 v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") init_op = tf.global_variables_initializer() # 初始化全部變量 saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 聲明tf.train.Saver類用于保存模型 with tf.Session() as sess: sess.run(init_op) print("v1:", sess.run(v1)) # 打印v1、v2的值一會讀取之后對比 print("v2:", sess.run(v2)) saver_path = saver.save(sess, "save/model.ckpt") # 將模型保存到save/model.ckpt文件 print("Model saved in file:", saver_path)
注:Saver方法已經(jīng)發(fā)生了更改,現(xiàn)在是V2版本,tf.train.Saver(write_version=tf.train.SaverDef.V1)括號里加入該參數(shù)可繼續(xù)使用V1,但會報(bào)warning,可忽略。若使用saver = tf.train.Saver()則默認(rèn)使用當(dāng)前的版本(V2),保存后在save這個(gè)文件夾中會出現(xiàn)4個(gè)文件,比V1版多出model.ckpt.data-00000-of-00001
這個(gè)文件,這點(diǎn)感謝評論里那位朋友指出。至于這個(gè)文件的含義到目前我仍不是很清楚,也沒查到具體資料,TensorFlow15年底開源到現(xiàn)在很多類啊函數(shù)都一直發(fā)生著變動,或被更新或被棄用,可能一些代碼在當(dāng)時(shí)是沒問題的,但過了一大段時(shí)間后再跑可能就會報(bào)錯(cuò),在此注明事件時(shí)間:2017.4.30
這段代碼中,通過saver.save
函數(shù)將TensorFlow模型保存到了save/model.ckpt文件中,這里代碼中指定路徑為"save/model.ckpt"
,也就是保存到了當(dāng)前程序所在文件夾里面的save
文件夾中。
TensorFlow模型會保存在后綴為.ckpt
的文件中。保存后在save這個(gè)文件夾中會出現(xiàn)3個(gè)文件,因?yàn)門ensorFlow會將計(jì)算圖的結(jié)構(gòu)和圖上參數(shù)取值分開保存。
checkpoint
文件保存了一個(gè)目錄下所有的模型文件列表,這個(gè)文件是tf.train.Saver
類自動生成且自動維護(hù)的。在 checkpoint文件中維護(hù)了由一個(gè)tf.train.Saver類持久化的所有TensorFlow模型文件的文件名。當(dāng)某個(gè)保存的TensorFlow模型文件被刪除時(shí),這個(gè)模型所對應(yīng)的文件名也會從checkpoint
文件中刪除。checkpoint中內(nèi)容的格式為CheckpointState Protocol Buffer.
model.ckpt.meta
文件保存了TensorFlow計(jì)算圖的結(jié)構(gòu),可以理解為神經(jīng)網(wǎng)絡(luò)的網(wǎng)絡(luò)結(jié)構(gòu)
TensorFlow通過元圖(MetaGraph)來記錄計(jì)算圖中節(jié)點(diǎn)的信息以及運(yùn)行計(jì)算圖中節(jié)點(diǎn)所需要的元數(shù)據(jù)。TensorFlow中元圖是由MetaGraphDef Protocol Buffer定義的。MetaGraphDef 中的內(nèi)容構(gòu)成了TensorFlow持久化時(shí)的第一個(gè)文件。保存MetaGraphDef 信息的文件默認(rèn)以.meta為后綴名,文件model.ckpt.meta中存儲的就是元圖數(shù)據(jù)。
model.ckpt
文件保存了TensorFlow程序中每一個(gè)變量的取值,這個(gè)文件是通過SSTable格式存儲的,可以大致理解為就是一個(gè)(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在這個(gè)文件中存儲的變量列表。列表剩下的每一行保存了一個(gè)變量的片段,變量片段的信息是通過SavedSlice Protocol Buffer定義的。SavedSlice類型中保存了變量的名稱、當(dāng)前片段的信息以及變量取值。TensorFlow提供了tf.train.NewCheckpointReader
類來查看model.ckpt
文件中保存的變量信息。如何使用tf.train.NewCheckpointReader類這里不做說明,自查。
下面代碼給出了加載TensorFlow模型的方法:
可以對比一下v1、v2的值是隨機(jī)初始化的值還是和之前保存的值是一樣的?
import tensorflow as tf # 使用和保存模型代碼中一樣的方式來聲明變量 v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") saver = tf.train.Saver() # 聲明tf.train.Saver類用于保存模型 with tf.Session() as sess: saver.restore(sess, "save/model.ckpt") # 即將固化到硬盤中的Session從保存路徑再讀取出來 print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的進(jìn)行對比 print("v2:", sess.run(v2)) print("Model Restored")
運(yùn)行結(jié)果:
v1: [[ 0.76705766 1.82217288]] v2: [[-0.98012197 1.23697340.5797025 ] [ 2.50458145 0.81897354 0.07858191]] Model Restored
這段加載模型的代碼基本上和保存模型的代碼是一樣的。也是先定義了TensorFlow計(jì)算圖上所有的運(yùn)算,并聲明了一個(gè)tf.train.Saver
類。兩段唯一的不同是,在加載模型的代碼中沒有運(yùn)行變量的初始化過程,而是將變量的值通過已經(jīng)保存的模型加載進(jìn)來。
也就是說使用TensorFlow完成了一次模型的保存和讀取的操作。
如果不希望重復(fù)定義圖上的運(yùn)算,也可以直接加載已經(jīng)持久化的圖:
import tensorflow as tf # 在下面的代碼中,默認(rèn)加載了TensorFlow計(jì)算圖上定義的全部變量 # 直接加載持久化的圖 saver = tf.train.import_meta_graph("save/model.ckpt.meta") with tf.Session() as sess: saver.restore(sess, "save/model.ckpt") # 通過張量的名稱來獲取張量 print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
運(yùn)行程序,輸出:
[[ 0.76705766 1.82217288]]
有時(shí)可能只需要保存或者加載部分變量。
比如,可能有一個(gè)之前訓(xùn)練好的5層神經(jīng)網(wǎng)絡(luò)模型,但現(xiàn)在想寫一個(gè)6層的神經(jīng)網(wǎng)絡(luò),那么可以將之前5層神經(jīng)網(wǎng)絡(luò)中的參數(shù)直接加載到新的模型,而僅僅將最后一層神經(jīng)網(wǎng)絡(luò)重新訓(xùn)練。
為了保存或者加載部分變量,在聲明tf.train.Saver
類時(shí)可以提供一個(gè)列表來指定需要保存或者加載的變量。比如在加載模型的代碼中使用saver = tf.train.Saver([v1])
命令來構(gòu)建tf.train.Saver類,那么只有變量v1會被加載進(jìn)來。
以上就是python深度學(xué)習(xí)TensorFlow神經(jīng)網(wǎng)絡(luò)模型的保存和讀取的詳細(xì)內(nèi)容,更多關(guān)于TensorFlow網(wǎng)絡(luò)模型保存和讀取的資料請關(guān)注本站其它相關(guān)文章!
版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非www.sddonglingsh.com所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場,如有內(nèi)容涉嫌侵權(quán),請聯(lián)系alex-e#qq.com處理。