Export the model to ONNX at the end of main() in run.py.

After the ".last" state dict has been saved, export `rencoder` (cast to
float) with torch.onnx.export, using a random float input of shape
(params['batch_size'], data_layer.vector_dim) that is moved to the GPU
when `use_gpu` is set. The result is written to model_checkpoint + ".onnx".

NOTE(review): the leading whitespace of every hunk line was lost in the
flattened copy of this patch. The indentation below is reconstructed at
4 spaces per level; confirm it against run.py before applying.

diff --git a/run.py b/run.py
index 3936781..471b673 100644
--- a/run.py
+++ b/run.py
@@ -241,5 +241,11 @@ def main():
     print("Saving model to {}".format(model_checkpoint + ".last"))
     torch.save(rencoder.state_dict(), model_checkpoint + ".last")
 
+    # save to onnx
+    dummy_input = Variable(torch.randn(params['batch_size'], data_layer.vector_dim).type(torch.float))
+    torch.onnx.export(rencoder.float(), dummy_input.cuda() if use_gpu else dummy_input,
+                      model_checkpoint + ".onnx", verbose=True)
+    print("ONNX model saved to {}!".format(model_checkpoint + ".onnx"))
+
 if __name__ == '__main__':
     main()