diff --git a/run.py b/run.py
index 0de4e04..3936781 100644
--- a/run.py
+++ b/run.py
@@ -36,6 +36,8 @@
                     help='if present, decoder\'s last layer will not apply non-linearity function')
 parser.add_argument('--num_epochs', type=int, default=50, metavar='N',
                     help='maximum number of epochs')
+parser.add_argument('--save_every', type=int, default=3, metavar='N',
+                    help='evaluate and save the model every N epochs (N <= 0: only after the final epoch)')
 parser.add_argument('--optimizer', type=str, default="momentum", metavar='N',
                     help='optimizer kind: adam, momentum, adagrad or rmsprop')
 parser.add_argument('--hidden_layers', type=str, default="1024,512,512,128", metavar='N',
@@ -229,7 +231,8 @@ def main():
           .format(epoch, e_end_time - e_start_time, sqrt(total_epoch_loss/denom)))
     logger.scalar_summary("Training_RMSE_per_epoch", sqrt(total_epoch_loss/denom), epoch)
     logger.scalar_summary("Epoch_time", e_end_time - e_start_time, epoch)
-    if epoch % 3 == 0 or epoch == args.num_epochs - 1:
+    # save_every <= 0 disables periodic evaluation (and avoids ZeroDivisionError for 0); the final epoch still runs.
+    if (args.save_every > 0 and epoch % args.save_every == 0) or epoch == args.num_epochs - 1:
       eval_loss = do_eval(rencoder, eval_data_layer)
       print('Epoch {} EVALUATION LOSS: {}'.format(epoch, eval_loss))
       logger.scalar_summary("EVALUATION_RMSE", eval_loss, epoch)