人妖在线一区,国产日韩欧美一区二区综合在线,国产啪精品视频网站免费,欧美内射深插日本少妇

新聞動(dòng)態(tài)

解決pytorch rnn 變長(zhǎng)輸入序列的問(wèn)題

發(fā)布日期:2022-03-30 15:21 | 文章來(lái)源:腳本之家

pytorch實(shí)現(xiàn)變長(zhǎng)輸入的rnn分類(lèi)

輸入數(shù)據(jù)是長(zhǎng)度不固定的序列數(shù)據(jù),主要講解兩個(gè)部分

1、Data.DataLoader的collate_fn用法,以及按batch進(jìn)行padding數(shù)據(jù)

2、pack_padded_sequence和pad_packed_sequence來(lái)處理變長(zhǎng)序列

collate_fn

Dataloader的collate_fn參數(shù),定義數(shù)據(jù)處理和合并成batch的方式。

由于pack_padded_sequence用到的tensor必須按照長(zhǎng)度從大到小排過(guò)序的,所以在Collate_fn中,需要完成兩件事,一是把當(dāng)前batch的樣本按照當(dāng)前batch最大長(zhǎng)度進(jìn)行padding,二是將padding后的數(shù)據(jù)從大到小進(jìn)行排序。

def pad_tensor(vec, pad):
 """
 args:
  vec - tensor to pad
  pad - the size to pad to
 return:
  a new tensor padded to 'pad'
 """
 return torch.cat([vec, torch.zeros(pad - len(vec), dtype=torch.float)], dim=0).data.numpy()
class Collate:
 """
 a variant of callate_fn that pads according to the longest sequence in
 a batch of sequences
 """
 def __init__(self):
  pass
 def _collate(self, batch):
  """
  args:
batch - list of (tensor, label)
  reutrn:
xs - a tensor of all examples in 'batch' before padding like:
 '''
 [tensor([1,2,3,4]),
  tensor([1,2]),
  tensor([1,2,3,4,5])]
 '''
ys - a LongTensor of all labels in batch like:
 '''
 [1,0,1]
 '''
  """
  xs = [torch.FloatTensor(v[0]) for v in batch]
  ys = torch.LongTensor([v[1] for v in batch])
  # 獲得每個(gè)樣本的序列長(zhǎng)度
  seq_lengths = torch.LongTensor([v for v in map(len, xs)])
  max_len = max([len(v) for v in xs])
  # 每個(gè)樣本都padding到當(dāng)前batch的最大長(zhǎng)度
  xs = torch.FloatTensor([pad_tensor(v, max_len) for v in xs])
  # 把xs和ys按照序列長(zhǎng)度從大到小排序
  seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
  xs = xs[perm_idx]
  ys = ys[perm_idx]
  return xs, seq_lengths, ys
 def __call__(self, batch):
  return self._collate(batch)

定義完collate類(lèi)以后,在DataLoader中直接使用

train_data = Data.DataLoader(dataset=train_dataset, batch_size=32, num_workers=0, collate_fn=Collate())

torch.nn.utils.rnn.pack_padded_sequence()

pack_padded_sequence將一個(gè)填充過(guò)的變長(zhǎng)序列壓緊。輸入?yún)?shù)包括

input(Variable)- 被填充過(guò)后的變長(zhǎng)序列組成的batch data

lengths (list[int]) - 變長(zhǎng)序列的原始序列長(zhǎng)度

batch_first (bool,optional) - 如果是True,input的形狀應(yīng)該是(batch_size,seq_len,input_size)

返回值:一個(gè)PackedSequence對(duì)象,可以直接作為rnn,lstm,gru的傳入數(shù)據(jù)。

用法:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# x是填充過(guò)后的batch數(shù)據(jù),seq_lengths是每個(gè)樣本的序列長(zhǎng)度
packed_input = pack_padded_sequence(x, seq_lengths, batch_first=True)

RNN模型

定義了一個(gè)單向的LSTM模型,因?yàn)樘幚淼氖亲冮L(zhǎng)序列,forward函數(shù)傳入的值是一個(gè)PackedSequence對(duì)象,返回值也是一個(gè)PackedSequence對(duì)象

class Model(nn.Module):
 def __init__(self, in_size, hid_size, n_layer, drop=0.1, bi=False):
  super(Model, self).__init__()
  self.lstm = nn.LSTM(input_size=in_size,
hidden_size=hid_size,
num_layers=n_layer,
batch_first=True,
dropout=drop,
bidirectional=bi)
  # 分類(lèi)類(lèi)別數(shù)目為2
  self.fc = nn.Linear(in_features=hid_size, out_features=2)
 def forward(self, x):
  '''
  :param x: 變長(zhǎng)序列時(shí),x是一個(gè)PackedSequence對(duì)象
  :return: PackedSequence對(duì)象
  '''
  # lstm_out: tensor of shape (batch, seq_len, num_directions * hidden_size)
  lstm_out, _ = self.lstm(x)  
  
  return lstm_out
model = Model()
lstm_out = model(packed_input)

torch.nn.utils.rnn.pad_packed_sequence()

這個(gè)操作和pack_padded_sequence()是相反的,把壓緊的序列再填充回來(lái)。因?yàn)榍懊嫣岬降腖STM模型傳入和返回的都是PackedSequence對(duì)象,所以我們?nèi)绻胍逊祷氐腜ackedSequence對(duì)象轉(zhuǎn)換回Tensor,就需要用到pad_packed_sequence函數(shù)。

參數(shù)說(shuō)明:

sequence (PackedSequence) – 將要被填充的 batch

batch_first (bool, optional) – 如果為T(mén)rue,返回的數(shù)據(jù)的形狀為(batch_size,seq_len,input_size)

返回值: 一個(gè)tuple,包含被填充后的序列,和batch中序列的長(zhǎng)度列表。

用法:

# 此處lstm_out是一個(gè)PackedSequence對(duì)象
output, _ = pad_packed_sequence(lstm_out)

返回的output是一個(gè)形狀為(batch_size,seq_len,input_size)的tensor。

總結(jié)

1、pytorch在自定義dataset時(shí),可以在DataLoader的collate_fn參數(shù)中定義對(duì)數(shù)據(jù)的變換,操作以及合成batch的方式。

2、處理變長(zhǎng)rnn問(wèn)題時(shí),通過(guò)pack_padded_sequence()將填充的batch數(shù)據(jù)轉(zhuǎn)換成PackedSequence對(duì)象,直接傳入rnn模型中。通過(guò)pad_packed_sequence()來(lái)將rnn模型輸出的PackedSequence對(duì)象轉(zhuǎn)換回相應(yīng)的Tensor。

補(bǔ)充:pytorch實(shí)現(xiàn)不定長(zhǎng)輸入的RNN / LSTM / GRU

情景描述

As we all know,RNN循環(huán)神經(jīng)網(wǎng)絡(luò)(及其改進(jìn)模型LSTM、GRU)可以處理序列的順序信息,如人類(lèi)自然語(yǔ)言。但是在實(shí)際場(chǎng)景中,我們常常向模型輸入一個(gè)批次(batch)的數(shù)據(jù),這個(gè)批次中的每個(gè)序列往往不是等長(zhǎng)的。

pytorch提供的模型(nn.RNN,nn.LSTM,nn.GRU)是支持可變長(zhǎng)序列的處理的,但條件是傳入的數(shù)據(jù)必須按序列長(zhǎng)度排序。本文針對(duì)以下兩種場(chǎng)景提出解決方法。

1、每個(gè)樣本只有一個(gè)序列:(seq,label),其中seq是一個(gè)長(zhǎng)度不定的序列。則使用pytorch訓(xùn)練時(shí),我們將按列把一個(gè)批次的數(shù)據(jù)輸入網(wǎng)絡(luò),seq這一列的形狀就是(batch_size, seq_len),經(jīng)過(guò)編碼層(如word2vec)之后的形狀是(batch_size, seq_len, emb_size)。

2、情況1的拓展:每個(gè)樣本有兩個(gè)(或多個(gè))序列,如(seq1, seq2, label)。這種樣本形式在問(wèn)答系統(tǒng)、推薦系統(tǒng)多見(jiàn)。

通用解決方案

定義ImprovedRnn類(lèi)。與nn.RNN,nn.LSTM,nn.GRU相比,除了此兩點(diǎn)【①forward函數(shù)多一個(gè)參數(shù)lengths表示每個(gè)seq的長(zhǎng)度】【②初始化函數(shù)(__init__)第一個(gè)參數(shù)module必須指定三者之一】外,使用方法完全相同。

import torch
from torch import nn
class ImprovedRnn(nn.Module):
 def __init__(self, module, *args, **kwargs):
  assert module in (nn.RNN, nn.LSTM, nn.GRU)
  super().__init__()
  self.module = module(*args, **kwargs)
 def forward(self, input, lengths):  # input shape(batch_size, seq_len, input_size)
  if not hasattr(self, '_flattened'):
self.module.flatten_parameters()
setattr(self, '_flattened', True)
  max_len = input.shape[1]
  # enforce_sorted=False則自動(dòng)按lengths排序,并且返回值package.unsorted_indices可用于恢復(fù)原順序
  package = nn.utils.rnn.pack_padded_sequence(input, lengths.cpu(), batch_first=self.module.batch_first, enforce_sorted=False)
  result, hidden = self.module(package)
  # total_length參數(shù)一般不需要,因?yàn)閘engths列表中一般含最大值。但分布式訓(xùn)練時(shí)是將一個(gè)batch切分了,故一定要有!
  result, lens = nn.utils.rnn.pad_packed_sequence(result, batch_first=self.module.batch_first, total_length=max_len)
  return result[package.unsorted_indices], hidden  # output shape(batch_size, seq_len, rnn_hidden_size)

使用示例:

class TestNet(nn.Module):
 def __init__(self, word_emb, gru_in, gru_out):
  super().__init__()
  self.encode = nn.Embedding.from_pretrained(torch.Tensor(word_emb))
  self.rnn = ImprovedRnn(nn.RNN, input_size=gru_in, hidden_size=gru_out,
		  				batch_first=True, bidirectional=True)
 def forward(self, seq1, seq1_lengths, seq2, seq2_lengths):
  seq1_emb = self.encode(seq1)
  seq2_emb = self.encode(seq2)
  rnn1, hn = self.rnn(seq1_emb, seq1_lengths)
  rnn2, hn = self.rnn(seq2_emb, seq2_lengths)
  """
  此處略去rnn1和rnn2的后續(xù)計(jì)算,當(dāng)前網(wǎng)絡(luò)最后計(jì)算結(jié)果記為prediction
  """
  return prediction

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持本站。

國(guó)外服務(wù)器租用

版權(quán)聲明:本站文章來(lái)源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請(qǐng)保持原文完整并注明來(lái)源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非www.sddonglingsh.com所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來(lái)源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來(lái),僅供學(xué)習(xí)參考,不代表本站立場(chǎng),如有內(nèi)容涉嫌侵權(quán),請(qǐng)聯(lián)系alex-e#qq.com處理。

相關(guān)文章

實(shí)時(shí)開(kāi)通

自選配置、實(shí)時(shí)開(kāi)通

免備案

全球線路精選!

全天候客戶(hù)服務(wù)

7x24全年不間斷在線

專(zhuān)屬顧問(wèn)服務(wù)

1對(duì)1客戶(hù)咨詢(xún)顧問(wèn)

在線
客服

在線客服:7*24小時(shí)在線

客服
熱線

400-630-3752
7*24小時(shí)客服服務(wù)熱線

關(guān)注
微信

關(guān)注官方微信
頂部