Tracking Validation Loss and Accuracy on RetinaNet

How to

RetinaNet is a deep learning model for object detection developed by Facebook AI Research, and it performs well on a wide range of detection tasks. One obstacle I ran into was saving the model only when the validation accuracy improved. With its default settings, however, RetinaNet's training script does not compute any validation metrics, so the checkpoint has nothing to monitor and the attempt results in an error. The fix turned out to be a couple of simple lines of code, but it took me quite a while to figure out.

The Code

To begin, install RetinaNet by following the instructions from the link below:

https://github.com/fizyr/keras-retinanet

The file you need to look at is train.py, which should be located at keras-retinanet/keras_retinanet/bin/train.py. The part to modify is its main function.

def main(args=None):
    ...
    ...
    # start training
    training_model.fit_generator(
        generator=train_generator,
        steps_per_epoch=args.steps,
        epochs=args.epochs,
        verbose=1,
        callbacks=callbacks,
    )
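To see why a best-only checkpoint has nothing to work with here, it helps to picture what a Keras-style fit loop records at the end of each epoch. The sketch below is a deliberately simplified, pure-Python model, not Keras' actual implementation: the function name run_epoch is hypothetical, and each "batch" is just a precomputed loss value. What it illustrates is that validation entries such as val_loss only appear in the epoch logs when validation data is passed in.

```python
from typing import Dict, Iterable, List, Optional


def run_epoch(
    train_batches: Iterable[float],
    val_batches: Optional[Iterable[float]] = None,
    validation_steps: Optional[int] = None,
) -> Dict[str, float]:
    """Toy model of the logs produced by one epoch of a Keras-style fit loop.

    Each "batch" is a precomputed loss value standing in for a real
    forward/backward pass. Validation only runs if val_batches is given.
    """
    train_losses: List[float] = list(train_batches)
    logs: Dict[str, float] = {"loss": sum(train_losses) / len(train_losses)}

    if val_batches is not None:
        val_losses: List[float] = []
        for step, batch_loss in enumerate(val_batches):
            # validation_steps caps how many validation batches are drawn per epoch
            if validation_steps is not None and step >= validation_steps:
                break
            val_losses.append(batch_loss)
        if val_losses:
            # simple mean of per-batch losses (a simplification)
            logs["val_loss"] = sum(val_losses) / len(val_losses)

    return logs


if __name__ == "__main__":
    # Without validation data there is no "val_loss" entry to monitor.
    print(run_epoch([1.0, 0.5]))  # {'loss': 0.75}
    # With validation data; only the first 2 validation batches are used.
    print(run_epoch([1.0, 0.5], [0.5, 0.25, 9.0], validation_steps=2))  # {'loss': 0.75, 'val_loss': 0.375}
```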

The problem with the original code is that training_model.fit_generator is called without any validation data, so there are no validation values for it to track. You can therefore modify the call into something like the following.

def main(args=None):
    ...
    ...
    # start training
    training_model.fit_generator(
        generator=train_generator,
        steps_per_epoch=args.steps,
        epochs=args.epochs,
        verbose=1,
        validation_steps=args.steps_for_validation,  # custom argument: make sure your parser defines it
        validation_data=validation_generator,  # evaluated at the end of every epoch
        callbacks=callbacks,
    )
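One caveat about the modified call: as far as I can tell, steps_for_validation is not one of the stock command-line options of keras-retinanet's train.py, so args.steps_for_validation will raise an AttributeError unless you add the option yourself. A minimal sketch of such an option, using Python's standard argparse module, is shown below. The parser here is a standalone, hypothetical stand-in for the one in train.py; the option name --steps-for-validation and both default values are assumptions for illustration.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical, standalone excerpt of a training-script argument parser."""
    parser = argparse.ArgumentParser(description="Training options (sketch).")
    # Shown for context; the default here is illustrative.
    parser.add_argument("--steps", type=int, default=10000,
                        help="Number of training steps per epoch.")
    # Hypothetical option. argparse converts the dashes to underscores,
    # so the parsed value is available as args.steps_for_validation.
    parser.add_argument("--steps-for-validation", type=int, default=100,
                        help="Number of validation batches per epoch (assumed default).")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args(["--steps-for-validation", "50"])
    print(args.steps, args.steps_for_validation)  # 10000 50
```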

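With validation_data and validation_steps set, Keras evaluates the validation generator at the end of every epoch and reports the results with a val_ prefix (for example val_loss), so you can track validation performance alongside the training loss. You can then save a snapshot only when the model improves on the validation set by modifying the checkpoint callback as below. Keras' ModelCheckpoint monitors val_loss by default; pass monitor (and mode) if you want it to track a different logged metric, such as a validation accuracy.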

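def create_callbacks(model, training_model, prediction_model, validation_generator, args):
    callbacks = []
    ...
    checkpoint = keras.callbacks.ModelCheckpoint(
        os.path.join(
            args.snapshot_path,
            ...
        ),
        verbose=1,
        # Only write a snapshot when the monitored value (val_loss by default) improves.
        save_best_only=True,
    )
    ...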
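For intuition about what save_best_only=True does, here is a small pure-Python model of the logic: the callback remembers the best value of the monitored metric seen so far and only saves when the current epoch beats it, where "beats" means lower for mode="min" (losses) and higher for mode="max" (accuracies). The class BestOnlyCheckpoint is hypothetical and records epoch numbers instead of writing files. It is a simplified sketch, not the source of Keras' ModelCheckpoint, and it handles a missing metric by simply skipping the save.

```python
import math
from typing import Dict, List


class BestOnlyCheckpoint:
    """Hypothetical, simplified model of best-only checkpointing."""

    def __init__(self, monitor: str = "val_loss", mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.monitor = monitor
        self.mode = mode
        # Start from the worst possible value so the first epoch always improves.
        self.best = math.inf if mode == "min" else -math.inf
        self.saved_epochs: List[int] = []

    def on_epoch_end(self, epoch: int, logs: Dict[str, float]) -> None:
        current = logs.get(self.monitor)
        if current is None:
            # No value to compare (e.g. no validation data); this sketch just skips.
            print(f"{self.monitor} not available, skipping save")
            return
        improved = current < self.best if self.mode == "min" else current > self.best
        if improved:
            self.best = current
            self.saved_epochs.append(epoch)  # a real callback would save the model here


if __name__ == "__main__":
    ckpt = BestOnlyCheckpoint(monitor="val_loss", mode="min")
    for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.5]):
        ckpt.on_epoch_end(epoch, {"val_loss": val_loss})
    print(ckpt.saved_epochs)  # [0, 1, 3]
```

In the usage example, epochs 0, 1 and 3 trigger a save because each sets a new lowest val_loss, while epoch 2 (0.8) does not beat the best value of 0.7.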

I hope this post saves you some time when you use RetinaNet and want to save your model only when its validation accuracy or loss improves.
