最近想到一個需求

就是如何使用 keras 的自定義ModelCheckpoint儲存loss與val_loss最高的最佳權重?
意思就是當loss 為0.5,val_loss為0.6時,我要視當下這權重的best loss為0.6
就是取高的
因為有些情況下
val_loss在訓練初期會很低
但是實際上train loss還很高
也就是val_loss的低只是單純猜的
因為驗證資料本身比較少
猜對機率較高
所以會產生這種奇怪問題
所以我想要寫一個自定義ModelCheckpoint來解決這問題
也就是儲存權重依據的loss或是acc都是依據train與val目前較差的那一個
這樣才能確保權重對於train與val都有效

然後特別的是
其實我一開始沒有什麼頭緒
所以我就把我的需求打給 chatGPT
沒想到她還真的給我一個函式來測試
雖然是有BUG的
但是我看得懂架構我就可以修改了
而且還真的可以運行

以下是範例程式碼

 

from tensorflow.keras.callbacks import Callback

class CustomModelCheckpoint(Callback):
    def __init__(self, model, filepath, monitors=['loss', 'val_loss'], 
                 verbose=1,save_weights_only=False, 
                 mode='min'):
        super(CustomModelCheckpoint, self).__init__()

        self.filepath = filepath
        self.verbose = verbose
        self.save_weights_only = save_weights_only
        self.best_weights = None
        
        self.mode = mode
        self.model = model
        self.monitors = monitors

        if self.mode == 'min':
            self.best_t_loss = np.Inf
        else:
            self.best_t_loss = -np.Inf
        
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        
        if self.mode == 'min':
            now_Best_loss = -np.Inf
        else:
            now_Best_loss = np.Inf
        
        for monitor in self.monitors:
            if self.mode == 'min':
                if logs.get(monitor) is not None and logs[monitor] > now_Best_loss:
                    now_Best_loss = logs[monitor]
            else:
                if logs.get(monitor) is not None and logs[monitor] < now_Best_loss:
                    now_Best_loss = logs[monitor]
        self.best_weights = None

        if (self.mode == 'min' and now_Best_loss < self.best_t_loss) or \
           (self.mode == 'max' and now_Best_loss > self.best_t_loss):
            if self.verbose > 0:
                print(' best weights improved from %0.5f to %0.5f,'
                    ' saving model to %s' % (self.best_t_loss, now_Best_loss, filepath))
            self.best_t_loss = now_Best_loss
            self.model.save_weights(filepath, overwrite=True)

 

使用方式則是加入callbacks就可以

 

saveBestF = CustomModelCheckpoint(net_final,
                                  monitors=['accuracy', 'val_accuracy'],
                                  filepath=filepath,
                                  verbose=1,
                                  mode='max',
                                  )
callbacks_list = [saveBestF]

history = net_final.fit(
    x=train_generator_C,
    validation_data=val_generator_C,
    epochs=NUM_EPOCHS,

    callbacks=callbacks_list,

    verbose=1,
    )

 

給大家參考囉

同時也感嘆chatGPT的神奇
難道我也要被chatGPT取代了嗎?