目次
全体像
原論文「Attention Is All You Need」にある図。
以下ではTransfomerのコードと共に構造を説明する。
なお、以降で登場する図や例は、「私は太郎です」を「I am Taro」と英訳する際のTransformerの様子を図示したものである。
ただし、BOSトークンやEOSトークンの話は無しで考えているので注意すること。
また、内積/行列積・ベクトル/行列の書き分けが曖昧になっている箇所があるので適宜読み替えること。
大雑把な全体の構成要素は以下の5つとなっている。
- Token Embedding
- Positional Encoding
- Multi-head Attention
- Add & Norm
- Feed Forward
全体の詳細な図
なお、全体の構造を1つの図に書き込んだら、文字が小さくなりすぎた。
見づらいので細かいところまで見たければ画像をクリックして別タブで見ることを推奨。
Token Embedding
Tokenizerで得られた文章の各トークンを埋め込み表現に変換するための処理。
Transformerでは通常は512次元、BERTだと768次元などの埋め込みベクトルを各トークン毎に得る。
また、図ではトークン長を200トークンとして考えているので、短い文章はPADトークンで埋められている。
import torch
import torch.nn as nn
class TokenEmbedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super(TokenEmbedding, self).__init__()
# 語彙のサイズと埋め込みの次元数を設定
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
# パディングトークンのIDを0として設定
self.pad_token_id = 0
# トークンの埋め込み行列を作成
self.lookup_table = nn.Embedding(
num_embeddings=self.vocab_size,
embedding_dim=self.embedding_dim,
padding_idx=self.pad_token_id,
)
# 埋め込み行列の重みを初期化
self.lookup_table.weight.data.normal_(0., self.embedding_dim ** -0.5)
def forward(self, input: torch.Tensor) -> torch.Tensor:
# パディングトークン以外の位置を示すマスクテンソルを作成
mask = input.ne(self.pad_token_id).float()
# トークンの埋め込みを取得
embedding = self.lookup_table(input)
# パディングトークンの位置に対応する埋め込みを0にする
embedding *= mask.unsqueeze(-1)
# 埋め込みの次元数をスケーリングする
return embedding * self.embedding_dim ** 0.5
Positional Encoding
各トークンの文書内での位置をToken Embeddingで作成したベクトルに埋め込む。
これによって文章内の単語の順序関係を処理できるようになる。
例えば、「私は犬より猫が好きです」と「私は猫より犬が好きです」は意味が全く違うが、出現するトークンは同じため、Token Embeddingだけではそれぞれの犬・猫の埋め込みが全く同じになってしまう。Positional Encodingを施せば、それぞれの犬・猫の埋め込みは別のベクトルになる。
import torch
import torch.nn as nn
import math
class AddPositionalEncoding(nn.Module):
'''
入力テンソルに対し、位置の情報を付与して返すレイヤー
PE_{pos, 2i} = sin(pos / 10000^{2i / d_model})
PE_{pos, 2i+1} = cos(pos / 10000^{2i / d_model})
本実装では学習するパラメータが存在しないが、BERT では Positional Encoding を変数にしてしまって学習で獲得する
'''
def forward(self, inputs):
batch_size, max_length, depth = inputs.size()
depth_counter = torch.arange(depth) // 2 * 2 # 0, 0, 2, 2, 4, ...
depth_matrix = torch.pow(10000.0, depth_counter / depth) # [depth]
# cos(x) == sin(x + π/2)
phase = torch.arange(depth) % 2 * math.pi / 2 # 0, π/2, 0, π/2, ...
phase_matrix = phase.unsqueeze(0).expand(max_length, depth) # [max_length, depth]
pos_counter = torch.arange(max_length).unsqueeze(1) # [max_length, 1]
pos_matrix = pos_counter.float() # [max_length, 1]
positional_encoding = torch.sin(pos_matrix / depth_matrix + phase_matrix) # [max_length, depth]
positional_encoding = positional_encoding.unsqueeze(0).expand(batch_size, max_length, depth) # [batch_size, max_length, depth]
return inputs + positional_encoding
Multi-head Attention
Multi-head Attentionは以下の図のような構造になっている。
Query・Key・Value
まずは、Positional Encodingが適用された埋め込みベクトルをQuery・Key・ValueのLinerに入力する。
以下の図ではEncoderMulti-head AttentioなのでQuery・Key・Valueすべての入力が同じだが、DecoderのMulti-head AttentioではQuery・Keyの入力はEncoderの出力となっている。
こうして得られたQuery・Key・Valueは辞書型と概念的に似たような感じで使う。(使い方は内積計算とかだが…)
Multi-head
Query・Key・Valueを任意の数(ここでは8個)に分割して処理する。(その方が良いらしい)
この8個に分割して処理する1つ1つの機構がheadなので、これが複数個あるためMulti-headという。
Scaled Dot-Product Attention
Attentionの肝。
Scaled Dot-Product Attention自体は以下のような構造で、学習可能なパラメータを持たない単なる計算処理となっている。
今回はMulti-headなので、以下のようにScaled Dot-Product Attentionもheadの数だけ分割して処理する。
ただし、以下の図は概念を理解するためにScaled Dot-Product Attentionをhead数分だけ増やしているが、実際は多次元のテンソル([batch_size, head_num, m_length, hidden_dim/head_num]
)として同時に計算されるため、実際にScaled Dot-Product Attentionが複数あるわけではない。
①QueryとKeyの内積
Scaled Dot-Product Attentionの最初の部分であるMatMulはQueryとKeyの内積計算である。
Queryが検索内容を表すベクトルであり、Keyと内積を取ることで、関連度ベクトルが得られる。
(内積はある種の類似度計算として扱うことができる)
関連度ベクトル:$\bf QK^T$
②Scale
関連度をスケーリングする。
スケーリングされた関連度ベクトル:$\frac{\bf QK^T}{\sqrt{d_k}}$
(Option)Mask
Decoderで使用されるMasked Multi-head AttentionではScaled Dot-Product Attention内でMask処理が行われる。
PADトークンを無視したりやDecoderで未来の情報を見れないようにしたりするために、該当箇所の関連度をー♾️にすることで、次のSoftmaxを通った後にその部分はゼロになる。
③Softmax
普通にSoftmaxを適用する。
最終的な関連度ベクトル:$softmax(\frac{\bf QK^T}{\sqrt{d_k}})$
④関連度ベクトルとValueの行列積
関連度ベクトルとValueの行列積を計算することで、ValueからQueryと関連する情報が抽出される。
Scaled Dot-Product Attentionの出力:$softmax(\frac{\bf QK^T}{\sqrt{d_k}})\bf V$
⑤その他の処理( Concat / Linaer)
Head毎に出力されたベクトルを Concatして、Linearに通す。
最終的なMulti-head Attention
import torch
import torch.nn as nn
class MultiheadAttention(nn.Module):
'''
Multi-head Attention のモデルです。
model = MultiheadAttention(
hidden_dim=512,
head_num=8,
dropout_rate=0.1,
)
model(query, memory, mask, training=True)
'''
def __init__(self, hidden_dim: int, head_num: int, dropout_rate: float):
'''
コンストラクタです。
:param hidden_dim: 隠れ層及び出力の次元
head_num の倍数である必要があります。
:param head_num: ヘッドの数
:param dropout_rate: ドロップアウトする確率
'''
super(MultiheadAttention, self).__init__()
self.hidden_dim = hidden_dim
self.head_num = head_num
self.dropout_rate = dropout_rate
# Query, Key, Value, Outputに対する線形変換のための全結合層を定義します
self.q_dense_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.k_dense_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.v_dense_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
# ドロップアウトを適用するレイヤー:ドロップアウトは、指定された確率でランダムに選択した要素を0にすることで、モデルの表現の多様性を増やし、過学習を防ぐことができる
self.attention_dropout_layer = nn.Dropout(dropout_rate)
# 出力層
self.output_dense_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
def forward(self,
input: torch.Tensor,
memory: torch.Tensor,
attention_mask: torch.Tensor,
training: bool,
) -> torch.Tensor:
'''
モデルの実行を行います。
:param input: query のテンソル
:param memory: query に情報を与える memory のテンソル
:param attention_mask: attention weight に適用される mask
shape = [batch_size, 1, q_length, k_length] のものです。
pad 等無視する部分が True となるようなものを指定してください。
:param training: 学習時か推論時かのフラグ
'''
# Query, Key, Valueのレイヤー:この部分が学習可能
## Query, Key, Valueの計算
q = self.q_dense_layer(input) # [batch_size, q_length, hidden_dim]
k = self.k_dense_layer(memory) # [batch_size, m_length, hidden_dim]
v = self.v_dense_layer(memory) # [batch_size, m_length, hidden_dim]
## Query, Key, Valueそれぞれを複数のヘッドに分割(Multi-head)
q = self._split_head(q) # [batch_size, head_num, q_length, hidden_dim/head_num]
k = self._split_head(k) # [batch_size, head_num, m_length, hidden_dim/head_num]
v = self._split_head(v) # [batch_size, head_num, m_length, hidden_dim/head_num]
# Scaled Dot-production:この部分は単なる計算のみで学習パラメータが存在しないため学習不可能
## MatNul
## QueryとKeyの内積を計算して関連度を表すlogitを得ます 内積≒類似度:内積は2つのベクトルが近いほど大きく、90度で0に、反対方向ならマイナスとなるため
logit = torch.matmul(q, k.transpose(-2, -1)) # [batch_size, q_length, k_length]
## Scale
depth = self.hidden_dim // self.head_num
scale = logit / (depth ** -0.5)
## Mask
## mask は pad 部分などが1, 他は0のベクトル
logit += attention_mask.to(torch.float32) * torch.finfo(input.dtype).min
## Softmax
## Softmax関数を用いて関連度を正規化します
attention_weight = nn.functional.softmax(scale, dim=-1)
## ドロップアウトの適用:モデルが学習中の場合にはドロップアウトを適用し、学習中でない場合にはドロップアウトを適用しない
attention_weight = self.attention_dropout_layer(attention_weight) if training else attention_weight
## MatNul
## 重みに基づいてValueから情報を引き出します
## Valueと関連度のベクトル(attention_weight)を行列積を計算することで関連する部分が抽出される
attention_output = torch.matmul(attention_weight, v) # [batch_size, q_length, hidden_dim/head_num]
## Concat
attention_output = self._combine_head(attention_output) # [batch_size, q_length, hidden_dim]
## Linear:この部分も学習可能?
# Attention出力を再度線形変換します
return self.output_dense_layer(attention_output) # [batch_size, q_length, hidden_dim]
def _split_head(self, x: torch.Tensor) -> torch.Tensor:
'''
入力の tensor の hidden_dim の次元をいくつかのヘッドに分割します。
入力 shape: [batch_size, length, hidden_dim] の時
出力 shape: [batch_size, head_num, length, hidden_dim//head_num]
となります。
'''
batch_size, length, hidden_dim = x.size()
x = x.view(batch_size, length, self.head_num, self.hidden_dim // self.head_num)
return x.transpose(1, 2)
def _combine_head(self, x: torch.Tensor) -> torch.Tensor:
'''
入力の tensor の各ヘッドを結合します。 _split_head の逆変換です。
入力 shape: [batch_size, head_num, length, hidden_dim//head_num] の時
出力 shape: [batch_size, length, hidden_dim]
となります。
'''
batch_size, _, length, _ = x.size()
x = x.transpose(1, 2).contiguous()
return x.view(batch_size, length, self.hidden_dim)
class SelfAttention(MultiheadAttention):
def forward(
self,
input: torch.Tensor,
attention_mask: torch.Tensor,
training: bool,
) -> torch.Tensor:
return super().forward(
input=input,
memory=input,
attention_mask=attention_mask,
training=training,
)
Add & Norm
正規化した後に残差接続を行う。
import torch
import torch.nn as nn
class LayerNormalization(nn.Module):
def __init__(self, hidden_dim: int, epsilon: float = 1e-6):
super(LayerNormalization, self).__init__()
self.hidden_dim = hidden_dim
self.epsilon = epsilon
self.scale = nn.Parameter(torch.ones(hidden_dim))
self.bias = nn.Parameter(torch.zeros(hidden_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 計算対象のテンソルxの最後の次元において平均値を計算します
mean = x.mean(dim=-1, keepdim=True)
# 計算対象のテンソルxから平均値を引いて二乗し、平均を取ります
variance = torch.mean(torch.square(x - mean), dim=-1, keepdim=True)
# 正規化されたテンソルを計算します
norm_x = (x - mean) * torch.rsqrt(variance + self.epsilon)
# 正規化されたテンソルにスケールとバイアスを適用します
return norm_x * self.scale + self.bias
class ResidualNormalizationWrapper(nn.Module):
'''
Add & Norm
'''
def __init__(self, layer: nn.Module, dropout_rate: float):
super(ResidualNormalizationWrapper, self).__init__()
self.layer = layer
self.layer_normalization = LayerNormalization(layer.hidden_dim)
self.dropout_layer = nn.Dropout(dropout_rate)
def forward(self, input: torch.Tensor, training: bool, *args, **kwargs) -> torch.Tensor:
# 入力テンソルに対してレイヤー正規化を適用します
tensor = self.layer_normalization(input)
# レイヤーの処理を実行し、ドロップアウトも適用します
tensor = self.layer(tensor, training=training, *args, **kwargs)
tensor = self.dropout_layer(tensor)
# 入力テンソルと処理されたテンソルを足し合わせて残差を計算します
return input + tensor
Feed Forward
ここまで線形層しかないので、活性化関数にReLUを使って非線形変換も行われるようにする。
また、DropOutなども適用しておく。
import torch
import torch.nn as nn
class FeedForwardNetwork(nn.Module):
'''
Multi-head Attention層にはLinearがあるが、活性関数が無く線形変換のみが行われているため、
ここでLinearの活性化関数にReLUを用いることで非線形変換を導入する
'''
def __init__(self, hidden_dim: int, dropout_rate: float):
super(FeedForwardNetwork, self).__init__()
self.hidden_dim = hidden_dim
self.dropout_rate = dropout_rate
# 入力次元数から出力次元数への線形変換層を定義します
self.filter_dense_layer = nn.Linear(hidden_dim, hidden_dim * 4)
# 活性化関数としてReLU関数を定義します
self.activation = nn.ReLU()
# ドロップアウトを適用する層を定義します
self.dropout_layer = nn.Dropout(dropout_rate)
# 出力次元数から入力次元数への線形変換層を定義します
self.output_dense_layer = nn.Linear(hidden_dim * 4, hidden_dim)
def forward(self, input: torch.Tensor, training: bool) -> torch.Tensor:
# 入力テンソルに対して線形変換を適用します
tensor = self.filter_dense_layer(input)
# 活性化関数を適用します
tensor = self.activation(tensor)
# ドロップアウトを適用します
tensor = self.dropout_layer(tensor)
# 再度線形変換を適用します
tensor = self.output_dense_layer(tensor)
return tensor
Encoder
原論文「Attention Is All You Need」の左側の部分である。
import torch
import torch.nn as nn
from typing import List
class Encoder(nn.Module):
def __init__(
self,
vocab_size: int,
hopping_num: int,
head_num: int,
hidden_dim: int,
dropout_rate: float,
max_length: int
) -> None:
super(Encoder, self).__init__()
self.hopping_num = hopping_num
self.head_num = head_num
self.hidden_dim = hidden_dim
self.dropout_rate = dropout_rate
# トークンの埋め込み層を定義します
self.token_embedding = TokenEmbedding(vocab_size, hidden_dim)
# 位置情報の埋め込み層を定義します
self.add_position_embedding = AddPositionalEncoding()
# 入力ドロップアウト層を定義します
self.input_dropout_layer = nn.Dropout(dropout_rate)
# Attentionブロックのリストを初期化します
self.attention_block_list: List[List[nn.Module]] = []
for _ in range(hopping_num):
# Attention層とFFN層を定義し、ResidualNormalizationWrapperでラップします
attention_layer = SelfAttention(hidden_dim, head_num, dropout_rate)
ffn_layer = FeedForwardNetwork(hidden_dim, dropout_rate)
self.attention_block_list.append([
ResidualNormalizationWrapper(attention_layer, dropout_rate),
ResidualNormalizationWrapper(ffn_layer, dropout_rate),
])
# 出力正規化層を定義します
self.output_normalization = LayerNormalization(hidden_dim)
def forward(
self,
input: torch.Tensor,
self_attention_mask: torch.Tensor,
training: bool,
) -> torch.Tensor:
# 入力テンソルに対してトークンの埋め込みを適用します
embedded_input = self.token_embedding(input)
# 埋め込まれた入力に位置情報の埋め込みを加えます
embedded_input = self.add_position_embedding(embedded_input)
# 入力ドロップアウトを適用します
query = self.input_dropout_layer(embedded_input)
# 各Attentionブロックを順番に処理します
for i, layers in enumerate(self.attention_block_list):
attention_layer, ffn_layer = tuple(layers)
# Attention層を適用します
query = attention_layer(query, attention_mask=self_attention_mask, training=training)
# FFN層を適用します
query = ffn_layer(query, training=training)
# 出力正規化を適用します
return self.output_normalization(query)
Decoder
原論文「Attention Is All You Need」の右側の部分である。
最初のMulti-head AttentionではMask処理を適用する。
また、2番目のMulti-head AttentionのQuery・KeyにはEncoderの出力を使用する。
import torch
import torch.nn as nn
from typing import List
class Decoder(nn.Module):
def __init__(
self,
vocab_size: int,
hopping_num: int,
head_num: int,
hidden_dim: int,
dropout_rate: float,
max_length: int
) -> None:
super(Decoder, self).__init__()
self.hopping_num = hopping_num
self.head_num = head_num
self.hidden_dim = hidden_dim
self.dropout_rate = dropout_rate
# トークンの埋め込み層を定義します
self.token_embedding = TokenEmbedding(vocab_size, hidden_dim)
# 位置情報の埋め込み層を定義します
self.add_position_embedding = AddPositionalEncoding()
# 入力ドロップアウト層を定義します
self.input_dropout_layer = nn.Dropout(dropout_rate)
# Attentionブロックのリストを初期化します
self.attention_block_list: List[List[nn.Module]] = []
for _ in range(hopping_num):
# 自己注意層、エンコーダ-デコーダ注意層、FFN層を定義し、ResidualNormalizationWrapperでラップします
self_attention_layer = SelfAttention(hidden_dim, head_num, dropout_rate)
enc_dec_attention_layer = MultiheadAttention(hidden_dim, head_num, dropout_rate)
ffn_layer = FeedForwardNetwork(hidden_dim, dropout_rate)
self.attention_block_list.append([
ResidualNormalizationWrapper(self_attention_layer, dropout_rate),
ResidualNormalizationWrapper(enc_dec_attention_layer, dropout_rate),
ResidualNormalizationWrapper(ffn_layer, dropout_rate),
])
# 出力正規化層を定義します
self.output_normalization = LayerNormalization(hidden_dim)
# 出力線形変換層を定義します
self.output_dense_layer = nn.Linear(hidden_dim, vocab_size, bias=False)
def forward(
self,
input: torch.Tensor,
encoder_output: torch.Tensor,
self_attention_mask: torch.Tensor,
enc_dec_attention_mask: torch.Tensor,
training: bool,
) -> torch.Tensor:
# 入力テンソルに対してトークンの埋め込みを適用します
embedded_input = self.token_embedding(input)
# 埋め込まれた入力に位置情報の埋め込みを加えます
embedded_input = self.add_position_embedding(embedded_input)
# 入力ドロップアウトを適用します
query = self.input_dropout_layer(embedded_input)
# 各Attentionブロックを順番に処理します
for i, layers in enumerate(self.attention_block_list):
self_attention_layer, enc_dec_attention_layer, ffn_layer = tuple(layers)
# 自己注意層を適用します
query = self_attention_layer(query, attention_mask=self_attention_mask, training=training)
# エンコーダ-デコーダ注意層を適用します
query = enc_dec_attention_layer(query, memory=encoder_output,
attention_mask=enc_dec_attention_mask, training=training)
# FFN層を適用します
query = ffn_layer(query, training=training)
# 出力正規化を適用します
query = self.output_normalization(query)
# 出力線形変換を適用します
return self.output_dense_layer(query)
Transformer
原論文「Attention Is All You Need」の構造になるようにEncoderとDecoderを接続し、Transformerを完成させる。
Decoderの出力をLinerとSoftmaxに通すと語彙数次元のベクトルが得られる。
これは各トークンの出現確率を表し、最も高い確率のトークンがEncoderに入力された文章の次に出現する可能性が高いトークンである。
なお、つまり、Encoderに「私は太郎です」Decoderに「I」を与えると出力された確率から次のトークンが「am」になる。すると今度は、Encoderに「私は太郎です」Decoderに「I am」を与えると出力された確率から次のトークンが「Taro」になる。という風に翻訳が進んでいく。
import torch
import torch.nn as nn
from myTransformer.Encoder import Encoder
from myTransformer.Decoder import Decoder
class Transformer(nn.Module):
def __init__(
self,
vocab_size: int,
hopping_num: int = 4,
head_num: int = 8,
hidden_dim: int = 512,
dropout_rate: float = 0.1,
max_length: int = 200
) -> None:
super(Transformer, self).__init__()
self.vocab_size = vocab_size
self.hopping_num = hopping_num
self.head_num = head_num
self.hidden_dim = hidden_dim
self.dropout_rate = dropout_rate
self.max_length = max_length
self.pad_token_id = 0
# エンコーダとデコーダのインスタンスを生成します
self.encoder = Encoder(
vocab_size=vocab_size,
hopping_num=hopping_num,
head_num=head_num,
hidden_dim=hidden_dim,
dropout_rate=dropout_rate,
max_length=max_length
)
self.decoder = Decoder(
vocab_size=vocab_size,
hopping_num=hopping_num,
head_num=head_num,
hidden_dim=hidden_dim,
dropout_rate=dropout_rate,
max_length=max_length
)
def forward(self, encoder_input: torch.Tensor, decoder_input: torch.Tensor, training: bool) -> torch.Tensor:
# エンコーダの入力に対してエンコード処理を行います
enc_attention_mask = self._create_enc_attention_mask(encoder_input)
encoder_output = self.encoder(
encoder_input,
self_attention_mask=enc_attention_mask,
training=training
)
# デコーダの入力に対してデコード処理を行います
dec_self_attention_mask = self._create_dec_self_attention_mask(decoder_input)
decoder_output = self.decoder(
decoder_input,
encoder_output,
self_attention_mask=dec_self_attention_mask,
enc_dec_attention_mask=enc_attention_mask,
training=training
)
return decoder_output
def _create_enc_attention_mask(self, encoder_input: torch.Tensor):
with torch.no_grad():
# パディング部分をマスクするためのテンソルを作成します
pad_array = (encoder_input == self.pad_token_id).unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, length]
return pad_array
def _create_dec_self_attention_mask(self, decoder_input: torch.Tensor):
with torch.no_grad():
batch_size, length = decoder_input.size()
# パディング部分と自己回帰部分をマスクするためのテンソルを作成します
pad_array = (decoder_input == self.pad_token_id).unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, length]
autoregression_array = torch.triu(torch.ones(length, length, dtype=torch.bool), diagonal=1) # 上三角行列(対角成分は0)の論理値テンソル
autoregression_array = autoregression_array.unsqueeze(0).unsqueeze(1) # [1, 1, length, length]
return pad_array | autoregression_array