PyTorchでEarly Stopping

2022年10月22日

開発ディレクトリにpytorchtools.pyとして以下のファイルを作成する。

import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the current validation loss and the
    model. Whenever the loss improves (by at least ``delta``), the model's
    ``state_dict`` is written to ``path``. After ``patience`` consecutive calls
    without such an improvement, ``early_stop`` is set to ``True``; the caller
    is responsible for checking it and breaking out of the training loop.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How many consecutive non-improving calls to tolerate
                            before ``early_stop`` is set.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): Function used to emit progress messages.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0  # consecutive calls without improvement
        self.best_score = None  # best score seen so far (score = -val_loss)
        self.early_stop = False  # set once counter reaches patience
        # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf  # lowest validation loss checkpointed so far
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and checkpoint the model on improvement.

        Args:
            val_loss (float): Validation loss for the current epoch (lower is better).
            model (torch.nn.Module): Model whose ``state_dict`` is saved on improvement.
        """
        # Negate the loss so that a higher score means a better model.
        score = -val_loss

        if self.best_score is None:
            # First call: there is nothing to compare against, so always checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Did not beat the best score by at least `delta`: count it as a miss.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: remember it, checkpoint, and reset the miss counter.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.

        Writes ``model.state_dict()`` to ``self.path`` and records ``val_loss``
        as the new minimum so the next verbose message reports the right delta.
        '''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

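使い方

以下の例は、`import numpy as np` と `import torch` が済んでいること、および学習対象のモデルが `model` という名前で定義済みであることを前提とする（モデルは `outputs, _ = model(x_)` のように2要素のタプルを返すもの、例えばBERT系のモデルを想定している）。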

from pytorchtools import EarlyStopping

# Prefer the first CUDA GPU when available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def _run_train_epoch(model, dataloader, optimizer, criterion):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss."""
    model.train()
    batch_losses = []
    for x_, t_ in dataloader:  # one optimizer step per mini-batch
        x_ = x_.to(device)
        t_ = t_.to(device)

        optimizer.zero_grad()
        # Assumes the model returns a 2-tuple whose first element is what
        # `criterion` expects (the original comment describes it as BERT input).
        outputs, _ = model(x_)
        loss = criterion(outputs, t_)  # compute the loss
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    if not batch_losses:
        # np.average([]) would return NaN and silently corrupt early stopping.
        raise ValueError("train_dataset produced no batches")
    return np.average(batch_losses)


def _run_validation(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradients; return the mean batch loss."""
    model.eval()
    batch_losses = []
    # no_grad: validation never calls backward(), so skip building autograd graphs.
    with torch.no_grad():
        for x_, t_ in dataloader:
            x_ = x_.to(device)
            t_ = t_.to(device)

            outputs, _ = model(x_)
            loss = criterion(outputs, t_)  # compute the loss

            batch_losses.append(loss.item())

    if not batch_losses:
        raise ValueError("dev_dataset produced no batches")
    return np.average(batch_losses)


def my_train(train_dataset, dev_dataset, optimizer, criterion, epochs, batch_size,
             *, model=None, patience=20, checkpoint_path='checkpoint.pt'):
    """Train a model with early stopping on the validation loss.

    Args:
        train_dataset: Dataset yielding ``(input, target)`` pairs for training.
        dev_dataset: Dataset yielding ``(input, target)`` pairs for validation.
        optimizer: Optimizer already bound to the model's parameters.
        criterion: Loss function called as ``criterion(outputs, target)``.
        epochs (int): Maximum number of epochs to run.
        batch_size (int): Mini-batch size for both loaders.
        model: Model to train. When omitted, the module-level global ``model``
            is used (the behaviour of the original version of this function).
        patience (int): Epochs without validation improvement before stopping.
        checkpoint_path (str): File where the best ``state_dict`` is saved.

    Returns:
        tuple: ``(model, avg_train_losses, avg_valid_losses)``. The model has the
        best checkpoint loaded; the lists hold one average loss per epoch run.

    Raises:
        NameError: if no model is passed and no global ``model`` exists.
        ValueError: if either dataset yields no batches.
    """
    if model is None:
        # Backward compatibility: the original read a global named `model`.
        try:
            model = globals()['model']
        except KeyError:
            raise NameError(
                "my_train() needs a model: pass model=... or define a global 'model'"
            ) from None

    avg_train_losses = []
    avg_valid_losses = []

    # If the network shape is fairly fixed, let cuDNN pick the fastest kernels.
    torch.backends.cudnn.benchmark = True

    early_stopping = EarlyStopping(patience=patience, verbose=True, path=checkpoint_path)

    # A shuffling DataLoader reshuffles on every iteration, so it can be built once.
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Validation order is irrelevant; shuffling only made the contents of the
    # last partial batch (and hence the averaged loss) vary from epoch to epoch.
    dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=batch_size, shuffle=False)

    epoch_len = len(str(epochs))
    for epoch in range(epochs):
        train_loss = _run_train_epoch(model, train_dataloader, optimizer, criterion)
        valid_loss = _run_validation(model, dev_dataloader, criterion)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        # Report the epoch 1-based so the final line reads "[N/N]".
        print_msg = (f'[{epoch + 1:>{epoch_len}}/{epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        print(print_msg)

        # early_stopping checks whether the validation loss decreased and,
        # if so, writes a checkpoint of the current model.
        early_stopping(valid_loss, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Load the best checkpoint from *this* run. If no epoch ran, no checkpoint
    # was written, and loading would pick up a stale or missing file.
    if early_stopping.best_score is not None:
        model.load_state_dict(torch.load(early_stopping.path, map_location=device))

    return model, avg_train_losses, avg_valid_losses