LoRAについての解説メモ

2023年6月15日

LoRAは全結合層に適用できる。
例えば、Pytorchで実装されたある事前学習済みモデルの全結合層の1つnn.Linear(5,5)にLoRAを適用する場合を考える。

元々の全結合層は下図の水色部分である。

この全結合層の事前学習済みの重みWoldは入力と出力の次元数がどちらも5なので、重みパラメータwold(i,j)の行列として以下のように表される。

Wold=(wold(1,1)wold(1,2)wold(1,3)wold(1,4)wold(1,5)wold(2,1)wold(2,2)wold(2,3)wold(2,4)wold(2,5)wold(3,1)wold(3,2)wold(3,3)wold(3,4)wold(3,5)wold(4,1)wold(4,2)wold(4,3)wold(4,4)wold(4,5)wold(5,1)wold(5,2)wold(5,3)wold(5,4)wold(5,5))

ファインチューニングではこの重みを更新することを目的としているが、この更新作業を数式化するとファインチューニング後の新しい重みWnew=Wold+ΔWとなる。すなわち、ファインチューニングはこの重みの差分ΔWを求めているということになる。

当然、普通にファインチューニングを行うと5×5の重み行列のパラメータ数である25個のパラメータについて計算する必要がある。そこで、LoRAでは図のオレンジ色の2つの全結合層A=nn.Linear(5,2)B=nn.Linear(2,5)を導入することで学習するパラメータを削減している。

LoRAを適用してチューニングする際に、元の事前学習済みモデルの重みを持つ全結合層(水色)はフリーズさせ、パラメータの更新は行わない。代わりに、新たに作成した全結合層A・B(オレンジ色)のパラメータを学習させる。全結合層A・B(オレンジ色)は元の全結合層(水色)の横に配置する。

学習では、「元の全結合層(水色)への入力を普通に元の全結合層(水色)へ与えて得た出力」と、「同じ入力を全結合層A(オレンジ色)に与えてその出力を全結合層B(オレンジ色)に与えて最終的に得られた出力」の2つのベクトルを足し合わせて(結合ではないため次元数はdのまま)次の層へ渡している。バックプロパゲーションの際には、元の全結合層(水色)の重みは固定されているため、全結合層A・B(オレンジ色)の重みのみが学習されることになる。

全結合層B(オレンジ色)は0で初期化されているため、チューニング時の最初は追加した全結合層A・B(オレンジ色)の影響はなく、元々学習したパラメータがそのまま使用される。その後、全結合層A・B(オレンジ色)の重みが更新されるが、元の全結合層(水色)からの出力は一切変化しないため、元の全結合層(水色)の出力と全結合層B(オレンジ色)の出力を足し合わせた最終的な出力には、全結合層A・B(オレンジ色)の重みの更新だけが影響する。すなわち、全結合層A・B(オレンジ色)の重みは「事前学習で獲得した重み」と「チューニング後の重み」との差分となっている。これは、Wnew=Wold+ΔWにおけるΔWに相当する。

ここで、全結合層A・B(オレンジ色)の重みを行列で表すと以下のようになる。

全結合層Aの重みWA=(wA(1,1)wA(1,2)wA(1,3)wA(1,4)wA(1,5)wA(2,1)wA(2,2)wA(2,3)wA(2,4)wA(2,5))

全結合層Bの重みWB=(wB(1,1)wB(1,2)wB(2,1)wB(2,2)wB(3,1)wB(3,2)wB(4,1)wB(4,2)wB(5,1)wB(5,2))

この行列積を計算するとWBWAは5×5の行列となる。

WBWA=(k=12wB(1,k)wA(k,1)k=12wB(1,k)wA(k,2)k=12wB(1,k)wA(k,3)k=12wB(1,k)wA(k,4)k=12wB(1,k)wA(k,5)k=12wB(2,k)wA(k,1)k=12wB(2,k)wA(k,2)k=12wB(2,k)wA(k,3)k=12wB(2,k)wA(k,4)k=12wB(2,k)wA(k,5)k=12wB(3,k)wA(k,1)k=12wB(3,k)wA(k,2)k=12wB(3,k)wA(k,3)k=12wB(3,k)wA(k,4)k=12wB(3,k)wA(k,5)k=12wB(4,k)wA(k,1)k=12wB(4,k)wA(k,2)k=12wB(4,k)wA(k,3)k=12wB(4,k)wA(k,4)k=12wB(4,k)wA(k,5)k=12wB(5,k)wA(k,1)k=12wB(5,k)wA(k,2)k=12wB(5,k)wA(k,3)k=12wB(5,k)wA(k,4)k=12wB(5,k)wA(k,5))


すなわちΔW=WBWAとなる。これを元の全結合層(水色)の重みにWnew=Wold+WBWAと足し合わせれば、ファインチューニング相当の操作ができたことになる。

推論時には、このようにして予め重みを足し合わせておけば、LoRAの全結合層A・B(オレンジ色)は必要ないので、推論にかかる計算コストは変化しない。

また、学習においてVRAMを使用するのは、勾配計算に必要なデータを保持するためであり、フリーズさせた層は推論時と同程度のVRAMしか使用しないため、更新が必要なパラメータ数を削減すれば学習に必要なVRAMサイズが少なく済む。

今回の例では、元々の全結合層のパラメータ数は5×5であり、Aの全結合層のパラメータ数は5×2、Bの全結合層のパラメータ数は2×5なので、学習すべきパラメータ数は25個から20個に削減できている。この場合はそこまで恩恵がないが、例えば元の入出力の次元がd=100であり、LoRAの中間層の次元r=2としたなら1万のパラメータを400に削減できることになり、パラメータ数が多いほど恩恵が大きくなる。なお、rは何でも良いが、1とか2でもさほど大きな性能低下はないらしい。(詳しくは原論文の実験結果を参照すること)

なお、Transformerベースのモデルでは、Multi-head AttentionにおけるScaled Dot-productionを行う前のQuery・Key・Valueの各Linear層などに適用して使うことが多いらしい。
原理的には、全結合層ならどこにでも使えそう…

LoRAのPEFTライブラリを使った適用方法