當(dāng)我們?cè)?/font>10.7 節(jié)遇到機(jī)器翻譯時(shí),我們?cè)O(shè)計(jì)了一個(gè)基于兩個(gè) RNN 的序列到序列 (seq2seq) 學(xué)習(xí)的編碼器-解碼器架構(gòu) ( Sutskever et al. , 2014 )。具體來(lái)說(shuō),RNN 編碼器將可變長(zhǎng)度序列轉(zhuǎn)換為固定形狀的上下文變量。然后,RNN 解碼器根據(jù)生成的標(biāo)記和上下文變量逐個(gè)標(biāo)記地生成輸出(目標(biāo))序列標(biāo)記。
回想一下我們?cè)谙旅嬷赜〉?/font>圖 10.7.2 (圖 11.4.1)以及一些額外的細(xì)節(jié)。通常,在 RNN 中,有關(guān)源序列的所有相關(guān)信息都由編碼器轉(zhuǎn)換為某種內(nèi)部固定維狀態(tài)表示。正是這種狀態(tài)被解碼器用作生成翻譯序列的完整和唯一的信息源。換句話(huà)說(shuō),seq2seq 機(jī)制將中間狀態(tài)視為可能作為輸入的任何字符串的充分統(tǒng)計(jì)。
圖 11.4.1序列到序列模型。編碼器生成的狀態(tài)是編碼器和解碼器之間唯一共享的信息。
雖然這對(duì)于短序列來(lái)說(shuō)是相當(dāng)合理的,但很明顯這對(duì)于長(zhǎng)序列來(lái)說(shuō)是不可行的,比如一本書(shū)的章節(jié),甚至只是一個(gè)很長(zhǎng)的句子。畢竟,一段時(shí)間后,中間表示中將根本沒(méi)有足夠的“空間”來(lái)存儲(chǔ)源序列中所有重要的內(nèi)容。因此,解碼器將無(wú)法翻譯又長(zhǎng)又復(fù)雜的句子。第一個(gè)遇到的人是 格雷夫斯 ( 2013 )當(dāng)他們?cè)噲D設(shè)計(jì)一個(gè) RNN 來(lái)生成手寫(xiě)文本時(shí)。由于源文本具有任意長(zhǎng)度,他們?cè)O(shè)計(jì)了一個(gè)可區(qū)分的注意力模型來(lái)將文本字符與更長(zhǎng)的筆跡對(duì)齊,其中對(duì)齊僅在一個(gè)方向上移動(dòng)。這反過(guò)來(lái)又利用了語(yǔ)音識(shí)別中的解碼算法,例如隱馬爾可夫模型 (Rabiner 和 Juang,1993 年)。
受到學(xué)??習(xí)對(duì)齊的想法的啟發(fā), Bahdanau等人。( 2014 )提出了一種沒(méi)有單向?qū)R限制的可區(qū)分注意力模型。在預(yù)測(cè)標(biāo)記時(shí),如果并非所有輸入標(biāo)記都相關(guān),則模型僅對(duì)齊(或關(guān)注)輸入序列中被認(rèn)為與當(dāng)前預(yù)測(cè)相關(guān)的部分。然后,這用于在生成下一個(gè)令牌之前更新當(dāng)前狀態(tài)。雖然在其描述中相當(dāng)無(wú)傷大雅,但這種Bahdanau 注意力機(jī)制可以說(shuō)已經(jīng)成為過(guò)去十年深度學(xué)習(xí)中最有影響力的想法之一,并催生了 Transformers (Vaswani等人,2017 年)以及許多相關(guān)的新架構(gòu)。
import tensorflow as tf
from d2l import tensorflow as d2l
11.4.1。模型
我們遵循第 10.7 節(jié)的 seq2seq 架構(gòu)引入的符號(hào) ,特別是(10.7.3)。關(guān)鍵思想是,而不是保持狀態(tài),即上下文變量c將源句子總結(jié)為固定的,我們動(dòng)態(tài)更新它,作為原始文本(編碼器隱藏狀態(tài))的函數(shù)ht) 和已經(jīng)生成的文本(解碼器隱藏狀態(tài)st′?1). 這產(chǎn)生 ct′, 在任何解碼時(shí)間步后更新 t′. 假設(shè)輸入序列的長(zhǎng)度T. 在這種情況下,上下文變量是注意力池的輸出:
我們用了st′?1作為查詢(xún),和 ht作為鍵和值。注意 ct′然后用于生成狀態(tài) st′并生成一個(gè)新令牌(參見(jiàn) (10.7.3))。特別是注意力權(quán)重 α使用由 ( 11.3.7 )定義的附加注意評(píng)分函數(shù)按照 (11.3.3)計(jì)算。這種使用注意力的 RNN 編碼器-解碼器架構(gòu)如圖 11.4.2所示。請(qǐng)注意,后來(lái)對(duì)該模型進(jìn)行了修改,例如在解碼器中包含已經(jīng)生成的標(biāo)記作為進(jìn)一步的上下文(即,注意力總和確實(shí)停止在T而是它繼續(xù)進(jìn)行t′?1). 例如,參見(jiàn)Chan等人。( 2015 )描述了這種應(yīng)用于語(yǔ)音識(shí)別的策略。
圖 11.4.2具有 Bahdanau 注意機(jī)制的 RNN 編碼器-解碼器模型中的層。
11.4.2。用注意力定義解碼器
要實(shí)現(xiàn)帶有注意力的 RNN 編碼器-解碼器,我們只需要重新定義解碼器(從注意力函數(shù)中省略生成的符號(hào)可以簡(jiǎn)化設(shè)計(jì))。讓我們通過(guò)定義一個(gè)意料之中的命名類(lèi)來(lái)開(kāi)始具有注意力的解碼器的基本接口 AttentionDecoder
。
我們需要在Seq2SeqAttentionDecoder
類(lèi)中實(shí)現(xiàn) RNN 解碼器。解碼器的狀態(tài)初始化為(i)編碼器最后一層在所有時(shí)間步的隱藏狀態(tài),用作注意力的鍵和值;(ii) 編碼器在最后一步的所有層的隱藏狀態(tài)。這用于初始化解碼器的隱藏狀態(tài);(iii) 編碼器的有效長(zhǎng)度,以排除注意力池中的填充標(biāo)記。在每個(gè)解碼時(shí)間步,解碼器最后一層的隱藏狀態(tài),在前一個(gè)時(shí)間步獲得,用作注意機(jī)制的查詢(xún)。注意機(jī)制的輸出和輸入嵌入都被連接起來(lái)作為 RNN 解碼器的輸入。
class Seq2SeqAttentionDecoder(AttentionDecoder):
def __init__(self, vocab_size, embed_size<
評(píng)論