LoRAについての解説メモ

2023年6月15日

LoRAは全結合層に適用できる。
例えば、Pytorchで実装されたある事前学習済みモデルの全結合層の1つnn.Linear(5,5)にLoRAを適用する場合を考える。

元々の全結合層は下図の水色部分である。

この全結合層の事前学習済みの重み$\bf W_{old}$は入力と出力の次元数がどちらも5なので、重みパラメータ$w_{old}(i, j)$の行列として以下のように表される。

$\bf W_{old}=
\pmatrix{
w_{old}(1, 1) &w_{old}(1, 2) & w_{old}(1, 3) & w_{old}(1, 4) & w_{old}(1, 5) \cr
w_{old}(2, 1) &w_{old}(2, 2) & w_{old}(2, 3) & w_{old}(2, 4) & w_{old}(2, 5) \cr
w_{old}(3, 1) &w_{old}(3, 2) & w_{old}(3, 3) & w_{old}(3, 4) & w_{old}(3, 5) \cr
w_{old}(4, 1) &w_{old}(4, 2) & w_{old}(4, 3) & w_{old}(4, 4) & w_{old}(4, 5) \cr
w_{old}(5, 1) &w_{old}(5, 2) & w_{old}(5, 3) & w_{old}(5, 4) & w_{old}(5, 5) \cr
}$

ファインチューニングではこの重みを更新することを目的としているが、この更新作業を数式化するとファインチューニング後の新しい重み$\bf W_{new}=\bf W_{old}+\bf \Delta W$となる。すなわち、ファインチューニングはこの重みの差分$\bf \Delta W$を求めているということになる。

当然、普通にファインチューニングを行うと5×5の重み行列のパラメータ数である25個のパラメータについて計算する必要がある。そこで、LoRAでは図のオレンジ色の2つの全結合層A=nn.Linear(5,2)B=nn.Linear(2,5)を導入することで学習するパラメータを削減している。

LoRAを適用してチューニングする際に、元の事前学習済みモデルの重みを持つ全結合層(水色)はフリーズさせ、パラメータの更新は行わない。代わりに、新たに作成した全結合層A・B(オレンジ色)のパラメータを学習させる。全結合層A・B(オレンジ色)は元の全結合層(水色)の横に配置する。

学習では、「元の全結合層(水色)への入力を普通に元の全結合層(水色)へ与えて得た出力」と、「同じ入力を全結合層A(オレンジ色)に与えてその出力を全結合層B(オレンジ色)に与えて最終的に得られた出力」の2つのベクトルを足し合わせて(結合ではないため次元数はdのまま)次の層へ渡している。バックプロパゲーションの際には、元の全結合層(水色)の重みは固定されているため、全結合層A・B(オレンジ色)の重みのみが学習されることになる。

全結合層B(オレンジ色)は0で初期化されているため、チューニング時の最初は追加した全結合層A・B(オレンジ色)の影響はなく、元々学習したパラメータがそのまま使用される。その後、全結合層A・B(オレンジ色)の重みが更新されるが、元の全結合層(水色)からの出力は一切変化しないため、元の全結合層(水色)の出力と全結合層B(オレンジ色)の出力を足し合わせた最終的な出力には、全結合層A・B(オレンジ色)の重みの更新だけが影響する。すなわち、全結合層A・B(オレンジ色)の重みは「事前学習で獲得した重み」と「チューニング後の重み」との差分となっている。これは、$\bf W_{new}=\bf W_{old}+\bf \Delta W$における$\bf \Delta W$に相当する。

ここで、全結合層A・B(オレンジ色)の重みを行列で表すと以下のようになる。

全結合層Aの重み$\bf W_{A}=
\pmatrix{
w_{A}(1, 1) &w_{A}(1, 2) & w_{A}(1, 3) & w_{A}(1, 4) & w_{A}(1, 5) \cr
w_{A}(2, 1) &w_{A}(2, 2) & w_{A}(2, 3) & w_{A}(2, 4) & w_{A}(2, 5) \cr
}$

全結合層Bの重み$\bf W_{B}=
\pmatrix{
w_{B}(1, 1) &w_{B}(1, 2) \cr
w_{B}(2, 1) &w_{B}(2, 2) \cr
w_{B}(3, 1) &w_{B}(3, 2) \cr
w_{B}(4, 1) &w_{B}(4, 2) \cr
w_{B}(5, 1) &w_{B}(5, 2) \cr
}$

この行列積を計算すると$\bf W_{B}\bf W_{A}$は5×5の行列となる。

$$\bf W_{B}\bf W_{A}= \\
\pmatrix{
\sum_{k=1}^{2} w_{B}(1, k)w_{A}(k, 1) & \sum_{k=1}^{2} w_{B}(1, k)w_{A}(k, 2) & \sum_{k=1}^{2} w_{B}(1, k)w_{A}(k, 3) & \sum_{k=1}^{2} w_{B}(1, k)w_{A}(k, 4) & \sum_{k=1}^{2} w_{B}(1, k)w_{A}(k, 5) \cr
\sum_{k=1}^{2} w_{B}(2, k)w_{A}(k, 1) & \sum_{k=1}^{2} w_{B}(2, k)w_{A}(k, 2) & \sum_{k=1}^{2} w_{B}(2, k)w_{A}(k, 3) & \sum_{k=1}^{2} w_{B}(2, k)w_{A}(k, 4) & \sum_{k=1}^{2} w_{B}(2, k)w_{A}(k, 5) \cr
\sum_{k=1}^{2} w_{B}(3, k)w_{A}(k, 1) & \sum_{k=1}^{2} w_{B}(3, k)w_{A}(k, 2) & \sum_{k=1}^{2} w_{B}(3, k)w_{A}(k, 3) & \sum_{k=1}^{2} w_{B}(3, k)w_{A}(k, 4) & \sum_{k=1}^{2} w_{B}(3, k)w_{A}(k, 5) \cr
\sum_{k=1}^{2} w_{B}(4, k)w_{A}(k, 1) & \sum_{k=1}^{2} w_{B}(4, k)w_{A}(k, 2) & \sum_{k=1}^{2} w_{B}(4, k)w_{A}(k, 3) & \sum_{k=1}^{2} w_{B}(4, k)w_{A}(k, 4) & \sum_{k=1}^{2} w_{B}(4, k)w_{A}(k, 5) \cr
\sum_{k=1}^{2} w_{B}(5, k)w_{A}(k, 1) & \sum_{k=1}^{2} w_{B}(5, k)w_{A}(k, 2) & \sum_{k=1}^{2} w_{B}(5, k)w_{A}(k, 3) & \sum_{k=1}^{2} w_{B}(5, k)w_{A}(k, 4) & \sum_{k=1}^{2} w_{B}(5, k)w_{A}(k, 5) \cr
}$$


すなわち$\bf \Delta W=\bf W_{B}\bf W_{A}$となる。これを元の全結合層(水色)の重みに$\bf W_{new}=\bf W_{old}+\bf W_{B}\bf W_{A}$と足し合わせれば、ファインチューニング相当の操作ができたことになる。

推論時には、このようにして予め重みを足し合わせておけば、LoRAの全結合層A・B(オレンジ色)は必要ないので、推論にかかる計算コストは変化しない。

また、学習においてVRAMを使用するのは、勾配計算に必要なデータを保持するためであり、フリーズさせた層は推論時と同程度のVRAMしか使用しないため、更新が必要なパラメータ数を削減すれば学習に必要なVRAMサイズが少なく済む。

今回の例では、元々の全結合層のパラメータ数は5×5であり、Aの全結合層のパラメータ数は5×2、Bの全結合層のパラメータ数は2×5なので、学習すべきパラメータ数は25個から20個に削減できている。この場合はそこまで恩恵がないが、例えば元の入出力の次元がd=100であり、LoRAの中間層の次元r=2としたなら1万のパラメータを400に削減できることになり、パラメータ数が多いほど恩恵が大きくなる。なお、rは何でも良いが、1とか2でもさほど大きな性能低下はないらしい。(詳しくは原論文の実験結果を参照すること)

なお、Transformerベースのモデルでは、Multi-head AttentionにおけるScaled Dot-productionを行う前のQuery・Key・Valueの各Linear層などに適用して使うことが多いらしい。
原理的には、全結合層ならどこにでも使えそう…

LoRAのPEFTライブラリを使った適用方法