Transformerについてのメモ

2023年6月15日

全体像

原論文「Attention Is All You Need」にある図。

Attention Is All You Need

以下ではTransfomerのコードと共に構造を説明する。
なお、以降で登場する図や例は、「私は太郎です」を「I am Taro」と英訳する際のTransformerの様子を図示したものである。
ただし、BOSトークンやEOSトークンの話は無しで考えているので注意すること。
また、内積/行列積・ベクトル/行列の書き分けが曖昧になっている箇所があるので適宜読み替えること。

大雑把な全体の構成要素は以下の5つとなっている。

  • Token Embedding
  • Positional Encoding
  • Multi-head Attention
  • Add & Norm
  • Feed Forward
全体の詳細な図

なお、全体の構造を1つの図に書き込んだら、文字が小さくなりすぎた。
見づらいので細かいところまで見たければ画像をクリックして別タブで見ることを推奨。

Token Embedding

Tokenizerで得られた文章の各トークンを埋め込み表現に変換するための処理。
Transformerでは通常は512次元、BERTだと768次元などの埋め込みベクトルを各トークン毎に得る。
また、図ではトークン長を200トークンとして考えているので、短い文章はPADトークンで埋められている。

import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super(TokenEmbedding, self).__init__()

        # 語彙のサイズと埋め込みの次元数を設定
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        # パディングトークンのIDを0として設定
        self.pad_token_id = 0

        # トークンの埋め込み行列を作成
        self.lookup_table = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.embedding_dim,
            padding_idx=self.pad_token_id,
        )

        # 埋め込み行列の重みを初期化
        self.lookup_table.weight.data.normal_(0., self.embedding_dim ** -0.5)

    def forward(self, input: torch.Tensor) -> torch.Tensor:

        # パディングトークン以外の位置を示すマスクテンソルを作成
        mask = input.ne(self.pad_token_id).float()

        # トークンの埋め込みを取得
        embedding = self.lookup_table(input)

        # パディングトークンの位置に対応する埋め込みを0にする
        embedding *= mask.unsqueeze(-1)

        # 埋め込みの次元数をスケーリングする
        return embedding * self.embedding_dim ** 0.5

Positional Encoding

各トークンの文書内での位置をToken Embeddingで作成したベクトルに埋め込む。
これによって文章内の単語の順序関係を処理できるようになる。

例えば、「私は犬より猫が好きです」と「私は猫より犬が好きです」は意味が全く違うが、出現するトークンは同じため、Token Embeddingだけではそれぞれの犬・猫の埋め込みが全く同じになってしまう。Positional Encodingを施せば、それぞれの犬・猫の埋め込みは別のベクトルになる。

Positional Encodingの詳細と元の埋め込みの意味が壊れてしまわないのか?

位置埋め込みを使う場合、Transformerへの入力は次のようになる。

$$
\mathbf{x}_i = \mathbf{e}_i + \mathbf{p}_i
$$

ここで、

– $\mathbf{e}_i$ は位置 $i$ にある単語の埋め込み
– $\mathbf{p}_i$ は位置 $i$ の位置埋め込み
– $\mathbf{x}_i$ はTransformerに渡される入力

である。

たとえば、「猫」という単語が文中の4番目に現れた場合、単純化すると次のような処理になる。

$$
\text{入力ベクトル}
=
\text{「猫」の意味ベクトル}
+
\text{「4番目」の位置ベクトル}
$$

これにより、同じ「猫」という単語でも、現れる場所によって少し異なる入力ベクトルになる。

ここで疑問が生じる。

> 単語の埋め込みに位置情報を足したら、元の意味が崩れてしまうのではないか。

結論から言えば、完全に元の形のまま保存されるわけではない。
しかし、Transformerはそれで問題なく学習できる。

理由は、Transformerが最初から

$$
\mathbf{e}_i + \mathbf{p}_i
$$

という混ざった入力を前提として学習するからである。

位置埋め込みは、学習後に突然加えられるノイズではない。
モデルは学習の初期段階から、意味情報と位置情報が重なったベクトルを受け取る。

そのうえで、後段の重みを調整する。

したがって、モデルは次第に、

– どの成分が単語の意味に関係するか
– どの成分が位置に関係するか
– どの場面で位置情報を強く使うべきか

を学習していく。

単語埋め込みは、通常数百次元から数千次元のベクトルである。

たとえば、非常に単純化して3次元で考える。

「猫」の単語埋め込みが次のようなベクトルだったとする。

$$
\mathbf{e}_{\text{猫}}
=
\begin{bmatrix}
0.8 \\
0.1 \\
-0.3
\end{bmatrix}
$$

また、4番目の位置埋め込みが次のようなベクトルだったとする。

$$
\mathbf{p}_4
=
\begin{bmatrix}
0.05 \\
-0.02 \\
0.04
\end{bmatrix}
$$

このとき、Transformerに入力されるベクトルは次のようになる。

$$
\mathbf{x}_4
=
\mathbf{e}_{\text{猫}}
+
\mathbf{p}_4
=
\begin{bmatrix}
0.85 \\
0.08 \\
-0.26
\end{bmatrix}
$$

元のベクトルから少し変化しているが、「猫」に関する特徴が完全に消えているわけではない。

実際のTransformerでは、ベクトルは数百次元以上ある。
そのため、単語の意味と位置のような複数種類の情報を、同じベクトル空間に重ねて保持できる。情報は重なっているが、読み取り方を学べば必要な情報を取り出せる。

Transformer内部では、入力ベクトルに対して線形変換が何度も行われる。

Self-Attentionでは、入力からQuery、Key、Valueを作る。

$$
\mathbf{q}_i = W_Q \mathbf{x}_i
$$

$$
\mathbf{k}_i = W_K \mathbf{x}_i
$$

$$
\mathbf{v}_i = W_V \mathbf{x}_i
$$

ここで、

$$
\mathbf{x}_i = \mathbf{e}_i + \mathbf{p}_i
$$

なので、Queryについて展開すると、

$$
\mathbf{q}_i
=
W_Q(\mathbf{e}_i + \mathbf{p}_i)
$$

$$
\mathbf{q}_i
=
W_Q\mathbf{e}_i + W_Q\mathbf{p}_i
$$

となる。

つまり、単語の意味と位置の情報は、線形変換後もそれぞれ別の成分として残る。

学習によって、Transformerは次のような使い分けを獲得できる。

– 単語の意味を重視するAttention Head
– 近くの単語を重視するAttention Head
– 文頭付近を重視するAttention Head
– 特定の位置関係を重視するAttention Head

意味と位置を完全に分離して持つ必要はない。
必要に応じて、異なる方向へ射影できればよい。

なぜ [0, 0, 1] のような表現では駄目なのか

位置を表すだけなら、3番目を次のように表す方法も考えられる。

$$
\mathbf{p}_3
=
\begin{bmatrix}
0 \\
0 \\
1
\end{bmatrix}
$$

これはone-hot表現と呼ばれる。

位置1、位置2、位置3をそれぞれ次のように表す。

$$
\mathbf{p}_1
=
\begin{bmatrix}
1 \\
0 \\
0
\end{bmatrix},
\qquad
\mathbf{p}_2
=
\begin{bmatrix}
0 \\
1 \\
0
\end{bmatrix},
\qquad
\mathbf{p}_3
=
\begin{bmatrix}
0 \\
0 \\
1
\end{bmatrix}
$$

この方法でも、位置を識別すること自体は可能である。

しかし、Transformerの位置表現としては不便である。

one-hot表現では、最大系列長がベクトルの次元数になる。

最大系列長が512なら、位置ベクトルは512次元必要である。

最大系列長が32768なら、位置ベクトルも32768次元必要になる。

長い文章を扱うほど、位置表現そのものが大きくなる。

一方、sinとcosを使う方法では、位置ベクトルの次元数を固定できる。

たとえば埋め込み次元が768なら、位置1でも位置10000でも768次元のままである。

one-hot表現では、位置1と位置2は直交する。

$$
\mathbf{p}_1^\mathsf{T}\mathbf{p}_2 = 0
$$

位置1と位置1000も、同様に直交する。

$$
\mathbf{p}_1^\mathsf{T}\mathbf{p}_{1000} = 0
$$

つまり、位置1から見れば、

– 位置2
– 位置1000

は同じくらい異なる位置に見える。

しかし、自然言語では距離が重要である。

– 一つ前の単語
– 二つ後の単語
– 数十トークン前の単語
– 数百トークン離れた単語

では、役割が異なることが多い。

one-hot表現は、位置を区別できるが、距離の構造を自然には持っていない。

自然言語では、「一つ前を見る」「二つ後を見る」という規則が重要である。

しかし、one-hot表現では、

$$
\mathbf{p}_2 – \mathbf{p}_1
$$

$$
\mathbf{p}_{101} – \mathbf{p}_{100}
$$

に共通した単純な構造がない。

位置1から位置2への移動と、位置100から位置101への移動が、別々の変化に見える。

そのため、モデルは位置ごとに個別の関係を学習しやすい。

位置に依存せず、

> 一つ右に進む

という共通規則を表現できる方が望ましい。

sinとcosを用いる理由

Transformerの原論文では、sinとcosを使った位置埋め込みが提案されている。

式は次のようになる。

$$
PE(pos, 2i)
=
\sin\left(
\frac{pos}{10000^{2i/d}}
\right)
$$

$$
PE(pos, 2i+1)
=
\cos\left(
\frac{pos}{10000^{2i/d}}
\right)
$$

ここで、

– $pos$ はトークンの位置
– $i$ は埋め込みベクトル内の次元番号
– $d$ は埋め込み次元数

である。

難しそうに見えるが、考え方は単純である。

位置ごとに、少しずつ異なる波の値を割り当てる。

sinとcosを使う理由は、主に次の三つである。

1. 近い位置が似たベクトルになる
2. 複数の距離スケールを表現できる
3. 相対位置を扱いやすい

それぞれを見ていく。

sinとcosは、入力が少し変わると出力も少し変わる。

たとえば、

$$
\sin(10)
$$

$$
\sin(11)
$$

は異なる値だが、極端に離れた値ではない。

そのため、

$$
PE(10)
$$

$$
PE(11)
$$

も比較的似たベクトルになる。

一方、

$$
PE(10)
$$

$$
PE(1000)
$$

は、より異なるベクトルになりやすい。

これは自然言語に適している。

文章において、隣り合う単語と数百トークン離れた単語は、同じ位置関係ではない。

sinとcosを使うことで、位置表現に距離感を持たせられる。

sinとcosは、同じ形を繰り返す周期関数である。

たとえば、次のような波を考える。

$$
\sin(\theta)
$$

この波は、角度が $2\pi$ 増えると元の値に戻る。

$$
\sin(\theta)
=
\sin(\theta + 2\pi)
$$

つまり、一種類のsinだけで位置を表そうとすると、離れた位置でも同じ値になる可能性がある。

たとえば、位置をそのまま角度として使う単純な例では、

$$
\sin(0)
=
\sin(2\pi)
=
\sin(4\pi)
=
0
$$

となる。

これでは、位置0、位置$2\pi$、位置$4\pi$を区別できない。

位置埋め込みでは、この問題を避けるために、**異なる速さで変化する複数のsinとcos**を組み合わせる。

周期とは、波が一周して同じ値の並びに戻るまでの長さである。

たとえば、次の関数を考える。

$$
\sin(pos)
$$

この関数は、位置が約$6.28$進むと一周する。

$$
2\pi \approx 6.28
$$

一方、次の関数は、よりゆっくり変化する。

$$
\sin\left(\frac{pos}{10}\right)
$$

この場合、sinの中身が$2\pi$増えるまでには、位置が約$62.8$進む必要がある。

$$
20\pi \approx 62.8
$$

さらに、次の関数はもっとゆっくり変化する。

$$
\sin\left(\frac{pos}{100}\right)
$$

この波が一周するには、位置が約$628$進む必要がある。

$$
200\pi \approx 628
$$

つまり、分母が大きいほど波はゆっくり変化し、周期は長くなる。

位置埋め込みでは、短い周期の波と長い周期の波を同時に使う。

概念的には、次のような複数の波を用意する。

$$
\sin(pos)
$$

$$
\sin\left(\frac{pos}{10}\right)
$$

$$
\sin\left(\frac{pos}{100}\right)
$$

$$
\sin\left(\frac{pos}{1000}\right)
$$

短い周期の波は、位置が少し変わっただけでも値が大きく変化する。

そのため、

– 一つ隣の位置
– 二つ先の位置
– 数トークン程度の違い

を細かく区別しやすい。

一方、長い周期の波は、近い位置では値がほとんど変わらない。

しかし、文章全体の中で大きく離れた位置を区別するのに役立つ。

たとえば、

– 文の前半か後半か
– 数十トークン離れているか
– 数百トークン離れているか

といった大きなスケールの違いを表しやすい。

短い周期の波は細かい目盛り、長い周期の波は大きな目盛りとして機能する。

単純な例として、次の二つの波を使う。

$$
\sin(pos)
$$

$$
\sin\left(\frac{pos}{10}\right)
$$

位置3と位置9を比較する。

短い周期の波では、

$$
\sin(3)
$$

$$
\sin(9)
$$

は異なる値になる。

長い周期の波でも、

$$
\sin\left(\frac{3}{10}\right)
$$

$$
\sin\left(\frac{9}{10}\right)
$$

は異なる値になる。

このように、位置ごとに複数の波の値を並べる。

$$
PE(3)
=
\begin{bmatrix}
\sin(3) \\
\cos(3) \\
\sin(3/10) \\
\cos(3/10)
\end{bmatrix}
$$

$$
PE(9)
=
\begin{bmatrix}
\sin(9) \\
\cos(9) \\
\sin(9/10) \\
\cos(9/10)
\end{bmatrix}
$$

一つの波だけを見るよりも、複数の周期を組み合わせた方が、位置を区別しやすい。

sinだけでも、位置に応じて値は変化する。

しかし、sinだけでは異なる位置が同じ値になる場合がある。

たとえば、

$$
\sin(\theta)
=
\sin(\pi – \theta)
$$

である。

具体的には、

$$
\sin\left(\frac{\pi}{6}\right)
=
\sin\left(\frac{5\pi}{6}\right)
=
\frac{1}{2}
$$

となる。

つまり、sinの値だけを見ると、

$$
\frac{\pi}{6}
$$

$$
\frac{5\pi}{6}
$$

を区別できない。

そこで、cosも組み合わせる。

$$
\left(
\sin\theta,
\cos\theta
\right)
$$

先ほどの例では、

$$
\left(
\sin\frac{\pi}{6},
\cos\frac{\pi}{6}
\right)
=
\left(
\frac{1}{2},
\frac{\sqrt{3}}{2}
\right)
$$

である。

一方、

$$
\left(
\sin\frac{5\pi}{6},
\cos\frac{5\pi}{6}
\right)
=
\left(
\frac{1}{2},
-\frac{\sqrt{3}}{2}
\right)
$$

となる。

sinの値は同じだが、cosの値が異なる。

そのため、sinとcosをペアで使えば、位置をより正確に区別できる。

sinとcosの組は、円周上の一点として考えられる。

$$
x = \cos\theta
$$

$$
y = \sin\theta
$$

角度$\theta$が変わると、点は円周上を移動する。

たとえば、

$$
\theta = 0
$$

では、

$$
(\sin\theta,\cos\theta)
=
(0,1)
$$

となる。

$$
\theta = \frac{\pi}{2}
$$

では、

$$
(\sin\theta,\cos\theta)
=
(1,0)
$$

となる。

$$
\theta = \pi
$$

では、

$$
(\sin\theta,\cos\theta)
=
(0,-1)
$$

となる。

つまり、sinとcosを組み合わせると、位置を円周上の向きとして表現できる。

Transformerのsin・cos位置埋め込みでは、次元ごとに異なる周期を使う。

式は次のようになる。

$$
PE(pos, 2i)
=
\sin\left(
\frac{pos}{10000^{2i/d}}
\right)
$$

$$
PE(pos, 2i+1)
=
\cos\left(
\frac{pos}{10000^{2i/d}}
\right)
$$

ここで、

– $pos$ はトークンの位置
– $i$ は次元の番号
– $d$ は埋め込み次元数

である。

重要なのは、次元ごとに分母が変わる点である。

$$
10000^{2i/d}
$$

分母が小さい次元では、波が速く変化する。

分母が大きい次元では、波がゆっくり変化する。

概念的には、次のようなベクトルになる。

$$
PE(pos)
=
\begin{bmatrix}
\sin(pos) \\
\cos(pos) \\
\sin(pos/10) \\
\cos(pos/10) \\
\sin(pos/100) \\
\cos(pos/100) \\
\sin(pos/1000) \\
\cos(pos/1000)
\end{bmatrix}
$$

実際には、分母は10倍ずつ増えるわけではない。
埋め込み次元全体にわたって、滑らかに変化する。

しかし、考え方は同じである。

各位置には、複数のsinとcosの値が割り当てられる。

たとえば、非常に単純化すると、位置3には次のような値が対応する。

$$
PE(3)
=
\begin{bmatrix}
\sin(3) \\
\cos(3) \\
\sin(3/10) \\
\cos(3/10) \\
\sin(3/100) \\
\cos(3/100)
\end{bmatrix}
$$

位置4には次のような値が対応する。

$$
PE(4)
=
\begin{bmatrix}
\sin(4) \\
\cos(4) \\
\sin(4/10) \\
\cos(4/10) \\
\sin(4/100) \\
\cos(4/100)
\end{bmatrix}
$$

位置3と位置4では、各成分が少しずつ異なる。

この値の組み合わせが、位置ごとの特徴として機能する。

一つの波だけでは同じ値が繰り返されても、複数の周期を組み合わせれば、完全に同じ組み合わせになる位置は現れにくい。

次の文を考える。

> 私は犬より猫が好きです

ここで、「犬」は3番目、「猫」は5番目にあるとする。

「犬」には位置3の位置埋め込みが加算される。

$$
\mathbf{x}_{\text{犬}}
=
\mathbf{e}_{\text{犬}}
+
PE(3)
$$

「猫」には位置5の位置埋め込みが加算される。

$$
\mathbf{x}_{\text{猫}}
=
\mathbf{e}_{\text{猫}}
+
PE(5)
$$

一方、次の文では位置が逆になる。

> 私は猫より犬が好きです

この場合、「猫」には位置3の位置埋め込みが加算される。

$$
\mathbf{x}_{\text{猫}}
=
\mathbf{e}_{\text{猫}}
+
PE(3)
$$

「犬」には位置5の位置埋め込みが加算される。

$$
\mathbf{x}_{\text{犬}}
=
\mathbf{e}_{\text{犬}}
+
PE(5)
$$

同じ「犬」と「猫」という単語でも、位置埋め込みが異なる。

そのため、Transformerは、

– 「犬」が「より」の前にある
– 「猫」が「より」の後にある

という違いを利用できる。

sinとcosにより相対的な位置関係をある程度扱える

sinとcosを使う利点は、位置を区別できることだけではない。

もう一つ重要なのは、**ある位置からどれだけ離れているか**という相対位置も扱いやすい点である。

文章では、単語が文中の何番目にあるかだけでなく、他の単語からどれだけ離れているかが重要になる。

たとえば、

– 一つ前の単語
– 二つ後ろの単語
– 数トークン先にある単語
– 離れた場所にある関連語

といった関係である。

自然言語では、絶対的な位置よりも、このような相対的な距離が役立つ場面が多い。

sinとcosには、次の加法定理がある。

$$
\sin(a+b)
=
\sin a \cos b
+
\cos a \sin b
$$

$$
\cos(a+b)
=
\cos a \cos b

\sin a \sin b
$$

ここで、現在位置を $pos$、そこからの移動量を $k$ とする。

すると、

$$
\sin(pos+k)
=
\sin(pos)\cos(k)
+
\cos(pos)\sin(k)
$$

$$
\cos(pos+k)
=
\cos(pos)\cos(k)

\sin(pos)\sin(k)
$$

となる。

この式が意味しているのは、

> 位置 $pos+k$ の表現は、位置 $pos$ の表現に対して、距離 $k$ に応じた共通の変換を加えることで求められる

ということである。

つまり、位置10から位置11へ進む場合も、位置100から位置101へ進む場合も、どちらも「一つ先へ進む」という同じ種類の変化として扱える。

前節で説明したように、sinとcosのペアは円周上の一点として考えられる。

$$
(\sin\theta,\cos\theta)
$$

位置が一つ進むと、円周上の点も少し回転する。

たとえば、位置 $pos$ から位置 $pos+1$ へ進む場合、点は一定の角度だけ回転する。

位置 $pos$ から位置 $pos+5$ へ進む場合は、より大きく回転する。

重要なのは、回転量が現在位置ではなく、移動距離によって決まることである。

たとえば、

– 位置3から位置4への移動
– 位置30から位置31への移動
– 位置300から位置301への移動

は、いずれも「一つ先へ進む」という同じ回転になる。

このように、sinとcosを使うと、位置の差を共通の規則として表現できる。

実際の位置埋め込みでは、複数の周期のsinとcosを使う。

つまり、位置が一つ進んだとき、それぞれの成分は異なる速さで回転する。

概念的には、次のような複数の円が同時に動く。

– 短い周期の成分は、大きく回転する
– 長い周期の成分は、少しだけ回転する

たとえば、次の二つの成分を考える。

$$
\sin(pos)
$$

$$
\sin\left(\frac{pos}{100}\right)
$$

位置が1増えた場合、最初の成分では角度が1増える。

一方、二つ目の成分では角度が$1/100$しか増えない。

つまり、短い周期の成分は細かな位置変化に敏感であり、長い周期の成分は大きな位置関係をゆっくり表す。

この複数の回転を組み合わせることで、Transformerは、

– 一つ隣にある
– 数トークン離れている
– 数十トークン離れている
– かなり遠くにある

といった距離の違いを複数の粒度で扱える。

one-hot表現でも、位置を区別すること自体はできる。

たとえば、

$$
\mathbf{p}_3
=
\begin{bmatrix}
0 \\
0 \\
1 \\
0
\end{bmatrix}
$$

$$
\mathbf{p}_4
=
\begin{bmatrix}
0 \\
0 \\
0 \\
1
\end{bmatrix}
$$

とすれば、位置3と位置4は異なるベクトルになる。

しかし、この表現では、

– 位置3から位置4への移動
– 位置100から位置101への移動

が同じ種類の変化として自然には表現されない。

one-hot表現は、位置を別々のラベルとして扱う。

一方、sinとcosによる位置埋め込みは、位置を連続的な座標として扱う。

この違いが、相対位置を扱ううえで重要になる。

次の文を考える。

> 私は犬より猫が好きです

この文では、「犬」は「より」の一つ前にある。

また、「猫」は「より」の一つ後ろにある。

Transformerが文の意味を理解するには、単に「犬」「猫」「より」という単語が含まれていると知るだけでは足りない。

重要なのは、

– 「犬」が「より」の前にある
– 「猫」が「より」の後ろにある

という位置関係である。

一方、次の文では関係が逆になる。

> 私は猫より犬が好きです

この場合、

– 「猫」が「より」の前にある
– 「犬」が「より」の後ろにある

という位置関係になる。

sinとcosによる位置埋め込みは、こうした「何番目か」だけでなく、「どちらが前にあり、どれだけ離れているか」という関係も扱いやすくする。

ここまでの内容を、文章全体に当てはめて考える。

次の文をトークン単位に分ける。

> 私は犬より猫が好きです

簡略化のため、次のように分割する。

| 位置 | トークン |
|—:|—|
| 1 | 私 |
| 2 | は |
| 3 | 犬 |
| 4 | より |
| 5 | 猫 |
| 6 | が |
| 7 | 好き |
| 8 | です |

各トークンには、単語埋め込みと位置埋め込みが加算される。

$$
\mathbf{x}_i
=
\mathbf{e}_i
+
PE(i)
$$

したがって、「犬」の入力ベクトルは次のようになる。

$$
\mathbf{x}_{\text{犬}}
=
\mathbf{e}_{\text{犬}}
+
PE(3)
$$

「猫」の入力ベクトルは次のようになる。

$$
\mathbf{x}_{\text{猫}}
=
\mathbf{e}_{\text{猫}}
+
PE(5)
$$

「より」の入力ベクトルは次のようになる。

$$
\mathbf{x}_{\text{より}}
=
\mathbf{e}_{\text{より}}
+
PE(4)
$$

このとき、Transformerは単語そのものの意味だけでなく、

– 「犬」が位置3にある
– 「より」が位置4にある
– 「猫」が位置5にある

という情報も利用できる。

さらに、前節で説明したように、位置埋め込みには相対的な関係も反映される。

そのため、

– 「犬」は「より」の一つ前にある
– 「猫」は「より」の一つ後ろにある

という関係も捉えやすい。

次に、次の文を考える。

> 私は猫より犬が好きです

トークンの並びは次のようになる。

| 位置 | トークン |
|—:|—|
| 1 | 私 |
| 2 | は |
| 3 | 猫 |
| 4 | より |
| 5 | 犬 |
| 6 | が |
| 7 | 好き |
| 8 | です |

この場合、「猫」の入力ベクトルは次のようになる。

$$
\mathbf{x}_{\text{猫}}
=
\mathbf{e}_{\text{猫}}
+
PE(3)
$$

「犬」の入力ベクトルは次のようになる。

$$
\mathbf{x}_{\text{犬}}
=
\mathbf{e}_{\text{犬}}
+
PE(5)
$$

先ほどの文と同じ単語が使われているが、単語埋め込みと位置埋め込みの組み合わせが変わる。

最初の文では、

$$
\mathbf{e}_{\text{犬}} + PE(3)
$$

$$
\mathbf{e}_{\text{猫}} + PE(5)
$$

である。

二つ目の文では、

$$
\mathbf{e}_{\text{猫}} + PE(3)
$$

$$
\mathbf{e}_{\text{犬}} + PE(5)
$$

である。

この違いにより、Transformerは二つの文を区別できる。

import torch
import torch.nn as nn
import math

class AddPositionalEncoding(nn.Module):
    '''
    入力テンソルに対し、位置の情報を付与して返すレイヤー
    PE_{pos, 2i}   = sin(pos / 10000^{2i / d_model})
    PE_{pos, 2i+1} = cos(pos / 10000^{2i / d_model})

    本実装では学習するパラメータが存在しないが、BERT では Positional Encoding を変数にしてしまって学習で獲得する
    '''
    def forward(self, inputs):
        batch_size, max_length, depth = inputs.size()

        depth_counter = torch.arange(depth) // 2 * 2  # 0, 0, 2, 2, 4, ...
        depth_matrix = torch.pow(10000.0, depth_counter / depth)  # [depth]

        # cos(x) == sin(x + π/2)
        phase = torch.arange(depth) % 2 * math.pi / 2  # 0, π/2, 0, π/2, ...
        phase_matrix = phase.unsqueeze(0).expand(max_length, depth)  # [max_length, depth]

        pos_counter = torch.arange(max_length).unsqueeze(1)  # [max_length, 1]
        pos_matrix = pos_counter.float()  # [max_length, 1]

        positional_encoding = torch.sin(pos_matrix / depth_matrix + phase_matrix)  # [max_length, depth]
        positional_encoding = positional_encoding.unsqueeze(0).expand(batch_size, max_length, depth)  # [batch_size, max_length, depth]

        return inputs + positional_encoding

Multi-head Attention

Multi-head Attentionは以下の図のような構造になっている。

Query・Key・Value

まずは、Positional Encodingが適用された埋め込みベクトルをQuery・Key・ValueのLinerに入力する。
以下の図ではEncoderMulti-head AttentioなのでQuery・Key・Valueすべての入力が同じだが、DecoderのMulti-head AttentioではQuery・Keyの入力はEncoderの出力となっている。

こうして得られたQuery・Key・Valueは辞書型と概念的に似たような感じで使う。(使い方は内積計算とかだが…)

Multi-head

Query・Key・Valueを任意の数(ここでは8個)に分割して処理する。(その方が良いらしい)
この8個に分割して処理する1つ1つの機構がheadなので、これが複数個あるためMulti-headという。

Scaled Dot-Product Attention

Attentionの肝。
Scaled Dot-Product Attention自体は以下のような構造で、学習可能なパラメータを持たない単なる計算処理となっている。

今回はMulti-headなので、以下のようにScaled Dot-Product Attentionもheadの数だけ分割して処理する。
ただし、以下の図は概念を理解するためにScaled Dot-Product Attentionをhead数分だけ増やしているが、実際は多次元のテンソル([batch_size, head_num, m_length, hidden_dim/head_num])として同時に計算されるため、実際にScaled Dot-Product Attentionが複数あるわけではない。

①QueryとKeyの内積

Scaled Dot-Product Attentionの最初の部分であるMatMulはQueryとKeyの内積計算である。

Queryが検索内容を表すベクトルであり、Keyと内積を取ることで、関連度ベクトルが得られる。
(内積はある種の類似度計算として扱うことができる)

関連度ベクトル:$\bf QK^T$

[batch_size, head_num, m_length, m_length]のようなシェープをしている。
つまり、全トークンT×Tの関連度マップだと考えることができる。

②Scale

関連度をスケーリングする。

スケーリングされた関連度ベクトル:$\frac{\bf QK^T}{\sqrt{d_k}}$

(Option)Mask

Decoderで使用されるMasked Multi-head AttentionではScaled Dot-Product Attention内でMask処理が行われる。
PADトークンを無視したりやDecoderで未来の情報を見れないようにしたりするために、該当箇所の関連度を$-\infty$にすることで、次のSoftmaxを通った後にその部分はゼロになる。

③Softmax

普通にSoftmaxを適用する。

最終的な関連度ベクトル:$softmax(\frac{\bf QK^T}{\sqrt{d_k}})$

④関連度ベクトルとValueの行列積

関連度ベクトルとValueの行列積を計算することで、ValueからQueryと関連する情報が抽出される。

Scaled Dot-Product Attentionの出力:$softmax(\frac{\bf QK^T}{\sqrt{d_k}})\bf V$

$O = softmax(\frac{\bf QK^T}{\sqrt{d_k}})\bf V = A V$と表すと、$O_i = \sum_{j=1}^{T} A_{i,j} V_j$となる。
つまり、以下のように$O_i$はi番目のトークンが他のトークン( j )に向ける注意で重み付けして$V_j$から取り出した情報の合計である。
(ただし、$A_{i,j}$はスカラー、$V_j$はベクトルである。)

$
O =
\begin{bmatrix}
A_{1,1} & A_{1,2} & \dots & A_{1,T} \cr
A_{2,1} & A_{2,2} & \dots & A_{2,T} \cr
\vdots & \vdots & \ddots & \vdots \cr
A_{T,1} & A_{T,2} & \dots & A_{T,T}
\end{bmatrix}
\begin{bmatrix}
V_1 \cr
V_2 \cr
\vdots \cr
V_T
\end{bmatrix}
=
\begin{bmatrix}
A_{1,1}V_1 + A_{1,2}V_2 + \dots + A_{1,T}V_T \cr
A_{2,1}V_1 + A_{2,2}V_2 + \dots + A_{2,T}V_T \cr
\vdots \cr
A_{T,1}V_1 + A_{T,2}V_2 + \dots + A_{T,T}V_T
\end{bmatrix}
$

⑤その他の処理( Concat / Linaer)

Head毎に出力されたベクトルを Concatして、Linearに通す。

最終的なMulti-head Attention

import torch
import torch.nn as nn

class MultiheadAttention(nn.Module):
    '''
    Multi-head Attention のモデルです。
    model = MultiheadAttention(
        hidden_dim=512,
        head_num=8,
        dropout_rate=0.1,
    )
    model(query, memory, mask, training=True)
    '''

    def __init__(self, hidden_dim: int, head_num: int, dropout_rate: float):
        '''
        コンストラクタです。
        :param hidden_dim: 隠れ層及び出力の次元
            head_num の倍数である必要があります。
        :param head_num: ヘッドの数
        :param dropout_rate: ドロップアウトする確率
        '''
        super(MultiheadAttention, self).__init__()

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.dropout_rate = dropout_rate

        # Query, Key, Value, Outputに対する線形変換のための全結合層を定義します
        self.q_dense_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k_dense_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v_dense_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # ドロップアウトを適用するレイヤー:ドロップアウトは、指定された確率でランダムに選択した要素を0にすることで、モデルの表現の多様性を増やし、過学習を防ぐことができる
        self.attention_dropout_layer = nn.Dropout(dropout_rate)

        # 出力層
        self.output_dense_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)





    def forward(self,
                input: torch.Tensor,
                memory: torch.Tensor,
                attention_mask: torch.Tensor,
                training: bool,
        ) -> torch.Tensor:
        '''
        モデルの実行を行います。
        :param input: query のテンソル
        :param memory: query に情報を与える memory のテンソル
        :param attention_mask: attention weight に適用される mask
            shape = [batch_size, 1, q_length, k_length] のものです。
            pad 等無視する部分が True となるようなものを指定してください。
        :param training: 学習時か推論時かのフラグ
        '''


        # Query, Key, Valueのレイヤー:この部分が学習可能

        ## Query, Key, Valueの計算
        q = self.q_dense_layer(input)  # [batch_size, q_length, hidden_dim]
        k = self.k_dense_layer(memory)  # [batch_size, m_length, hidden_dim]
        v = self.v_dense_layer(memory)  # [batch_size, m_length, hidden_dim]

        ## Query, Key, Valueそれぞれを複数のヘッドに分割(Multi-head)
        q = self._split_head(q)  # [batch_size, head_num, q_length, hidden_dim/head_num]
        k = self._split_head(k)  # [batch_size, head_num, m_length, hidden_dim/head_num]
        v = self._split_head(v)  # [batch_size, head_num, m_length, hidden_dim/head_num]


        # Scaled Dot-production:この部分は単なる計算のみで学習パラメータが存在しないため学習不可能

        ## MatNul
        ## QueryとKeyの内積を計算して関連度を表すlogitを得ます 内積≒類似度:内積は2つのベクトルが近いほど大きく、90度で0に、反対方向ならマイナスとなるため
        logit = torch.matmul(q, k.transpose(-2, -1))  # [batch_size, q_length, k_length]

        ## Scale
        depth = self.hidden_dim // self.head_num
        scale = logit / (depth ** -0.5)

        ## Mask
        ## mask は pad 部分などが1, 他は0のベクトル
        logit += attention_mask.to(torch.float32) * torch.finfo(input.dtype).min

        ## Softmax
        ## Softmax関数を用いて関連度を正規化します
        attention_weight = nn.functional.softmax(scale, dim=-1)

        ## ドロップアウトの適用:モデルが学習中の場合にはドロップアウトを適用し、学習中でない場合にはドロップアウトを適用しない
        attention_weight = self.attention_dropout_layer(attention_weight) if training else attention_weight

        ## MatNul
        ## 重みに基づいてValueから情報を引き出します
        ## Valueと関連度のベクトル(attention_weight)を行列積を計算することで関連する部分が抽出される
        attention_output = torch.matmul(attention_weight, v)  # [batch_size, q_length, hidden_dim/head_num]

        ## Concat
        attention_output = self._combine_head(attention_output)  # [batch_size, q_length, hidden_dim]

        ## Linear:この部分も学習可能?
        # Attention出力を再度線形変換します
        return self.output_dense_layer(attention_output) # [batch_size, q_length, hidden_dim]







    def _split_head(self, x: torch.Tensor) -> torch.Tensor:
        '''
        入力の tensor の hidden_dim の次元をいくつかのヘッドに分割します。
        入力 shape: [batch_size, length, hidden_dim] の時
        出力 shape: [batch_size, head_num, length, hidden_dim//head_num]
        となります。
        '''
        batch_size, length, hidden_dim = x.size()
        x = x.view(batch_size, length, self.head_num, self.hidden_dim // self.head_num)
        return x.transpose(1, 2)

    def _combine_head(self, x: torch.Tensor) -> torch.Tensor:
        '''
        入力の tensor の各ヘッドを結合します。 _split_head の逆変換です。
        入力 shape: [batch_size, head_num, length, hidden_dim//head_num] の時
        出力 shape: [batch_size, length, hidden_dim]
        となります。
        '''
        batch_size, _, length, _ = x.size()
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, length, self.hidden_dim)











class SelfAttention(MultiheadAttention):
    def forward(
            self,
            input: torch.Tensor,
            attention_mask: torch.Tensor,
            training: bool,
    ) -> torch.Tensor:
        return super().forward(
            input=input,
            memory=input,
            attention_mask=attention_mask,
            training=training,
        )

Add & Norm

正規化した後に残差接続を行う。

import torch
import torch.nn as nn


class LayerNormalization(nn.Module):
    def __init__(self, hidden_dim: int, epsilon: float = 1e-6):
        super(LayerNormalization, self).__init__()
        self.hidden_dim = hidden_dim
        self.epsilon = epsilon
        self.scale = nn.Parameter(torch.ones(hidden_dim))
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 計算対象のテンソルxの最後の次元において平均値を計算します
        mean = x.mean(dim=-1, keepdim=True)
        # 計算対象のテンソルxから平均値を引いて二乗し、平均を取ります
        variance = torch.mean(torch.square(x - mean), dim=-1, keepdim=True)
        # 正規化されたテンソルを計算します
        norm_x = (x - mean) * torch.rsqrt(variance + self.epsilon)

        # 正規化されたテンソルにスケールとバイアスを適用します
        return norm_x * self.scale + self.bias






class ResidualNormalizationWrapper(nn.Module):
    '''
    Add & Norm
    '''
    def __init__(self, layer: nn.Module, dropout_rate: float):
        super(ResidualNormalizationWrapper, self).__init__()
        self.layer = layer
        self.layer_normalization = LayerNormalization(layer.hidden_dim)
        self.dropout_layer = nn.Dropout(dropout_rate)

    def forward(self, input: torch.Tensor, training: bool, *args, **kwargs) -> torch.Tensor:
        # 入力テンソルに対してレイヤー正規化を適用します
        tensor = self.layer_normalization(input)
        # レイヤーの処理を実行し、ドロップアウトも適用します
        tensor = self.layer(tensor, training=training, *args, **kwargs)
        tensor = self.dropout_layer(tensor)
        # 入力テンソルと処理されたテンソルを足し合わせて残差を計算します
        return input + tensor

Feed Forward

ここまで線形層しかないので、活性化関数にReLUを使って非線形変換も行われるようにする。
また、DropOutなども適用しておく。

import torch
import torch.nn as nn


class FeedForwardNetwork(nn.Module):
    '''
    Multi-head Attention層にはLinearがあるが、活性関数が無く線形変換のみが行われているため、
    ここでLinearの活性化関数にReLUを用いることで非線形変換を導入する
    '''
    def __init__(self, hidden_dim: int, dropout_rate: float):
        super(FeedForwardNetwork, self).__init__()
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate

        # 入力次元数から出力次元数への線形変換層を定義します
        self.filter_dense_layer = nn.Linear(hidden_dim, hidden_dim * 4)
        # 活性化関数としてReLU関数を定義します
        self.activation = nn.ReLU()
    
        # ドロップアウトを適用する層を定義します
        self.dropout_layer = nn.Dropout(dropout_rate)

        # 出力次元数から入力次元数への線形変換層を定義します
        self.output_dense_layer = nn.Linear(hidden_dim * 4, hidden_dim)



    def forward(self, input: torch.Tensor, training: bool) -> torch.Tensor:
        # 入力テンソルに対して線形変換を適用します
        tensor = self.filter_dense_layer(input)
        # 活性化関数を適用します
        tensor = self.activation(tensor)

        # ドロップアウトを適用します
        tensor = self.dropout_layer(tensor)
        
        # 再度線形変換を適用します
        tensor = self.output_dense_layer(tensor)
        return tensor

Encoder

原論文「Attention Is All You Need」の左側の部分である。

Attention Is All You Need

import torch
import torch.nn as nn
from typing import List

class Encoder(nn.Module):
    def __init__(
            self,
            vocab_size: int,
            hopping_num: int,
            head_num: int,
            hidden_dim: int,
            dropout_rate: float,
            max_length: int
    ) -> None:
        super(Encoder, self).__init__()
        self.hopping_num = hopping_num
        self.head_num = head_num
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate

        # トークンの埋め込み層を定義します
        self.token_embedding = TokenEmbedding(vocab_size, hidden_dim)
        # 位置情報の埋め込み層を定義します
        self.add_position_embedding = AddPositionalEncoding()
        # 入力ドロップアウト層を定義します
        self.input_dropout_layer = nn.Dropout(dropout_rate)

        # Attentionブロックのリストを初期化します
        self.attention_block_list: List[List[nn.Module]] = []
        for _ in range(hopping_num):
            # Attention層とFFN層を定義し、ResidualNormalizationWrapperでラップします
            attention_layer = SelfAttention(hidden_dim, head_num, dropout_rate)
            ffn_layer = FeedForwardNetwork(hidden_dim, dropout_rate)
            self.attention_block_list.append([
                ResidualNormalizationWrapper(attention_layer, dropout_rate),
                ResidualNormalizationWrapper(ffn_layer, dropout_rate),
            ])
        # 出力正規化層を定義します
        self.output_normalization = LayerNormalization(hidden_dim)

    def forward(
            self,
            input: torch.Tensor,
            self_attention_mask: torch.Tensor,
            training: bool,
    ) -> torch.Tensor:
        # 入力テンソルに対してトークンの埋め込みを適用します
        embedded_input = self.token_embedding(input)
        # 埋め込まれた入力に位置情報の埋め込みを加えます
        embedded_input = self.add_position_embedding(embedded_input)
        # 入力ドロップアウトを適用します
        query = self.input_dropout_layer(embedded_input)

        # 各Attentionブロックを順番に処理します
        for i, layers in enumerate(self.attention_block_list):
            attention_layer, ffn_layer = tuple(layers)
            # Attention層を適用します
            query = attention_layer(query, attention_mask=self_attention_mask, training=training)
            # FFN層を適用します
            query = ffn_layer(query, training=training)

        # 出力正規化を適用します
        return self.output_normalization(query)

Decoder

原論文「Attention Is All You Need」の右側の部分である。
最初のMulti-head AttentionではMask処理を適用する。
また、2番目のMulti-head AttentionのQuery・KeyにはEncoderの出力を使用する。

Attention Is All You Need

import torch
import torch.nn as nn
from typing import List

class Decoder(nn.Module):
    def __init__(
            self,
            vocab_size: int,
            hopping_num: int,
            head_num: int,
            hidden_dim: int,
            dropout_rate: float,
            max_length: int
    ) -> None:
        super(Decoder, self).__init__()
        self.hopping_num = hopping_num
        self.head_num = head_num
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate

        # トークンの埋め込み層を定義します
        self.token_embedding = TokenEmbedding(vocab_size, hidden_dim)
        # 位置情報の埋め込み層を定義します
        self.add_position_embedding = AddPositionalEncoding()
        # 入力ドロップアウト層を定義します
        self.input_dropout_layer = nn.Dropout(dropout_rate)

        # Attentionブロックのリストを初期化します
        self.attention_block_list: List[List[nn.Module]] = []
        for _ in range(hopping_num):
            # 自己注意層、エンコーダ-デコーダ注意層、FFN層を定義し、ResidualNormalizationWrapperでラップします
            self_attention_layer = SelfAttention(hidden_dim, head_num, dropout_rate)
            enc_dec_attention_layer = MultiheadAttention(hidden_dim, head_num, dropout_rate)
            ffn_layer = FeedForwardNetwork(hidden_dim, dropout_rate)
            self.attention_block_list.append([
                ResidualNormalizationWrapper(self_attention_layer, dropout_rate),
                ResidualNormalizationWrapper(enc_dec_attention_layer, dropout_rate),
                ResidualNormalizationWrapper(ffn_layer, dropout_rate),
            ])
        # 出力正規化層を定義します
        self.output_normalization = LayerNormalization(hidden_dim)
        # 出力線形変換層を定義します
        self.output_dense_layer = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(
            self,
            input: torch.Tensor,
            encoder_output: torch.Tensor,
            self_attention_mask: torch.Tensor,
            enc_dec_attention_mask: torch.Tensor,
            training: bool,
    ) -> torch.Tensor:
        # 入力テンソルに対してトークンの埋め込みを適用します
        embedded_input = self.token_embedding(input)
        # 埋め込まれた入力に位置情報の埋め込みを加えます
        embedded_input = self.add_position_embedding(embedded_input)
        # 入力ドロップアウトを適用します
        query = self.input_dropout_layer(embedded_input)

        # 各Attentionブロックを順番に処理します
        for i, layers in enumerate(self.attention_block_list):
            self_attention_layer, enc_dec_attention_layer, ffn_layer = tuple(layers)
            # 自己注意層を適用します
            query = self_attention_layer(query, attention_mask=self_attention_mask, training=training)
            # エンコーダ-デコーダ注意層を適用します
            query = enc_dec_attention_layer(query, memory=encoder_output,
                                            attention_mask=enc_dec_attention_mask, training=training)
            # FFN層を適用します
            query = ffn_layer(query, training=training)

        # 出力正規化を適用します
        query = self.output_normalization(query)
        # 出力線形変換を適用します
        return self.output_dense_layer(query)

Transformer

原論文「Attention Is All You Need」の構造になるようにEncoderとDecoderを接続し、Transformerを完成させる。
Decoderの出力をLinerとSoftmaxに通すと語彙数次元のベクトルが得られる。
これは各トークンの出現確率を表し、最も高い確率のトークンがEncoderに入力された文章の次に出現する可能性が高いトークンである。
なお、つまり、Encoderに「私は太郎です」Decoderに「I」を与えると出力された確率から次のトークンが「am」になる。すると今度は、Encoderに「私は太郎です」Decoderに「I am」を与えると出力された確率から次のトークンが「Taro」になる。という風に翻訳が進んでいく。

Attention Is All You Need

import torch
import torch.nn as nn

from myTransformer.Encoder import Encoder
from myTransformer.Decoder import Decoder



class Transformer(nn.Module):
    def __init__(
            self,
            vocab_size: int,
            hopping_num: int = 4,
            head_num: int = 8,
            hidden_dim: int = 512,
            dropout_rate: float = 0.1,
            max_length: int = 200
    ) -> None:
        super(Transformer, self).__init__()
        self.vocab_size = vocab_size
        self.hopping_num = hopping_num
        self.head_num = head_num
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate
        self.max_length = max_length
        self.pad_token_id = 0

        # エンコーダとデコーダのインスタンスを生成します
        self.encoder = Encoder(
            vocab_size=vocab_size,
            hopping_num=hopping_num,
            head_num=head_num,
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate,
            max_length=max_length
        )
        self.decoder = Decoder(
            vocab_size=vocab_size,
            hopping_num=hopping_num,
            head_num=head_num,
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate,
            max_length=max_length
        )

    def forward(self, encoder_input: torch.Tensor, decoder_input: torch.Tensor, training: bool) -> torch.Tensor:
        # エンコーダの入力に対してエンコード処理を行います
        enc_attention_mask = self._create_enc_attention_mask(encoder_input)
        encoder_output = self.encoder(
            encoder_input,
            self_attention_mask=enc_attention_mask,
            training=training
        )

        # デコーダの入力に対してデコード処理を行います
        dec_self_attention_mask = self._create_dec_self_attention_mask(decoder_input)
        decoder_output = self.decoder(
            decoder_input,
            encoder_output,
            self_attention_mask=dec_self_attention_mask,
            enc_dec_attention_mask=enc_attention_mask,
            training=training
        )
        return decoder_output

    def _create_enc_attention_mask(self, encoder_input: torch.Tensor):
        with torch.no_grad():
            # パディング部分をマスクするためのテンソルを作成します
            pad_array = (encoder_input == self.pad_token_id).unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, length]
            return pad_array

    def _create_dec_self_attention_mask(self, decoder_input: torch.Tensor):
        with torch.no_grad():
            batch_size, length = decoder_input.size()

            # パディング部分と自己回帰部分をマスクするためのテンソルを作成します
            pad_array = (decoder_input == self.pad_token_id).unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, length]
            autoregression_array = torch.triu(torch.ones(length, length, dtype=torch.bool), diagonal=1)  # 上三角行列(対角成分は0)の論理値テンソル
            autoregression_array = autoregression_array.unsqueeze(0).unsqueeze(1)  # [1, 1, length, length]

            return pad_array | autoregression_array