GPTについてのメモ

2023年6月16日

文書生成

トークン列を与え、GPTの最終層のベクトルのうちEOSトークンより前のベクトルをLinear+Softmaxに入力して語彙数が次元となる確率分布を得る。この内、最も高いトークンが予測された次の単語となる。

前の入力で得た次単語を含めたトークン列を次の入力として与える。
そして、同様に次の単語が予想される。
これを繰り返すことで、文章が生成される。
なお、次の単語を選ぶときに、最も確率が高いものだけではなく、一定確率で2番目・3番目などの確率の単語を選ぶことで、捻りのある文章が生成される可能性がある。

なお、事前学習なども次単語予測を行っているので上図のようにしている。

その他のタスク

文書生成以外では、基本的にEOSトークンに対応するベクトルを用いる。
例えば、文書分類を行う場合は以下のようにしてクラスの予測ができる。

GPTの事前学習

モデル内で以下のように定義されている。
入力と正解ラベルはどちらも同じトークン列である。
例えば、「吾輩 は 猫 で ある」という文章を入力として与え、各トークンに対応する埋め込みからself.lm_headで次単語の確率分布を出力する。ここで予測した各トークンの次単語と実際の正解ラベルの次単語で誤差を計算して学習させる。この際に次単語の予測とその正解データを整形しているのがshift_logitsshift_labelsである。なお、GPTはtransformerのデコーダー部分なのでMasked Mult-Head Attentionがあり、この部分で入力トークンのうち、未来のトークンをカンニングしないようにマスクしている。つまり、「猫」の次単語を予測する際には「吾輩 は 猫」までのトークンは使用するが「で ある」に該当する箇所にはMasked Mult-Head Attentionでマスクされ、未来のトークンにアテンションが向かないようにしているため、カンニングが起こらない。

     transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))