
Fix save_checkpoint_fn return value #659


Merged 2 commits on Jun 2, 2022
2 changes: 1 addition & 1 deletion docs/tutorials/cheatsheet.rst
@@ -48,7 +48,7 @@ And to successfully resume from a checkpoint:
1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, buffer;
2. Set ``resume_from_log=True`` with trainer;

We provide an example to show how these steps work: checkout `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_ by running
We provide an example to show how these steps work: checkout `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_discrete_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/offline/test_discrete_bcq.py>`_ by running

.. code-block:: console

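A minimal, self-contained sketch of those two steps, following the pattern used by the tests changed in this PR. The names log_path and resume, and the plain torch.nn.Linear used as a stand-in policy, are illustrative placeholders rather than code from the PR itself:

import os

import torch

log_path = "log/example"  # placeholder log directory
os.makedirs(log_path, exist_ok=True)
policy = torch.nn.Linear(4, 2)  # stand-in for a real tianshou policy
optim = torch.optim.Adam(policy.parameters(), lr=1e-3)
resume = True  # would normally come from an argparse --resume flag

def save_checkpoint_fn(epoch, env_step, gradient_step):
    # save model + optimizer state and return the path that was written;
    # returning the path is the contract this PR fixes
    ckpt_path = os.path.join(log_path, "checkpoint.pth")
    torch.save({"model": policy.state_dict(), "optim": optim.state_dict()}, ckpt_path)
    return ckpt_path

# step 1: restore everything *before* trainer initialization
if resume:
    ckpt_path = os.path.join(log_path, "checkpoint.pth")
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        policy.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
# step 2: pass resume_from_log=True and save_checkpoint_fn=save_checkpoint_fn
# (plus a logger constructed with save_interval) when building the trainer

The trainer keyword arguments named in the last comment mirror the calls visible in the test diffs below.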
2 changes: 1 addition & 1 deletion examples/atari/atari_dqn.py
@@ -192,7 +192,7 @@ def test_fn(epoch, env_step):

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, "checkpoint.pth")
ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path

2 changes: 1 addition & 1 deletion examples/atari/atari_ppo.py
@@ -222,7 +222,7 @@ def stop_fn(mean_rewards):

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, "checkpoint.pth")
ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path

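Since the two Atari examples now write one file per epoch (checkpoint_{epoch}.pth), resuming typically means locating the newest file. A small helper sketch for that; latest_checkpoint is a hypothetical name, not part of the example scripts:

import glob
import os
import re

def latest_checkpoint(log_path):
    # pick the checkpoint with the highest epoch number embedded in its filename
    ckpts = glob.glob(os.path.join(log_path, "checkpoint_*.pth"))
    if not ckpts:
        return None
    return max(ckpts, key=lambda p: int(re.search(r"checkpoint_(\d+)\.pth", p).group(1)))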
28 changes: 16 additions & 12 deletions test/continuous/test_ppo.py
@@ -117,41 +117,45 @@ def dist(*logits):
dual_clip=args.dual_clip,
value_clip=args.value_clip,
gae_lambda=args.gae_lambda,
action_space=env.action_space
action_space=env.action_space,
)
# collector
train_collector = Collector(
policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
)
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
log_path = os.path.join(args.logdir, args.task, "ppo")
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer, save_interval=args.save_interval)

def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, "checkpoint.pth")
# Example: saving by epoch num
# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
torch.save(
{
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth')
"model": policy.state_dict(),
"optim": optim.state_dict(),
}, ckpt_path
)
return ckpt_path

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
ckpt_path = os.path.join(log_path, "checkpoint.pth")
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
policy.load_state_dict(checkpoint["model"])
optim.load_state_dict(checkpoint["optim"])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
@@ -171,7 +175,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
save_best_fn=save_best_fn,
logger=logger,
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn
save_checkpoint_fn=save_checkpoint_fn,
)

for epoch, epoch_stat, info in trainer:
@@ -181,7 +185,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):

assert stop_fn(info["best_reward"])

if __name__ == '__main__':
if __name__ == "__main__":
pprint.pprint(info)
# Let's watch its performance!
env = gym.make(args.task)
@@ -197,5 +201,5 @@ def test_ppo_resume(args=get_args()):
test_ppo(args)


if __name__ == '__main__':
if __name__ == "__main__":
test_ppo()
42 changes: 22 additions & 20 deletions test/discrete/test_c51.py
@@ -85,7 +85,7 @@ def test_c51(args=get_args()):
hidden_sizes=args.hidden_sizes,
device=args.device,
softmax=True,
num_atoms=args.num_atoms
num_atoms=args.num_atoms,
)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = C51Policy(
@@ -96,15 +96,15 @@ def test_c51(args=get_args()):
args.v_min,
args.v_max,
args.n_step,
target_update_freq=args.target_update_freq
target_update_freq=args.target_update_freq,
).to(args.device)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
args.buffer_size,
buffer_num=len(train_envs),
alpha=args.alpha,
beta=args.beta
beta=args.beta,
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -114,12 +114,12 @@ def test_c51(args=get_args()):
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
log_path = os.path.join(args.logdir, args.task, "c51")
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer, save_interval=args.save_interval)

def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold
@@ -140,29 +140,31 @@ def test_fn(epoch, env_step):

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, "checkpoint.pth")
# Example: saving by epoch num
# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
torch.save(
{
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth')
)
pickle.dump(
train_collector.buffer,
open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
"model": policy.state_dict(),
"optim": optim.state_dict(),
}, ckpt_path
)
buffer_path = os.path.join(log_path, "train_buffer.pkl")
pickle.dump(train_collector.buffer, open(buffer_path, "wb"))
return ckpt_path

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
ckpt_path = os.path.join(log_path, "checkpoint.pth")
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
policy.load_state_dict(checkpoint["model"])
policy.optim.load_state_dict(checkpoint["optim"])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
buffer_path = os.path.join(log_path, "train_buffer.pkl")
if os.path.exists(buffer_path):
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
print("Successfully restore buffer.")
@@ -186,11 +188,11 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
save_best_fn=save_best_fn,
logger=logger,
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn
save_checkpoint_fn=save_checkpoint_fn,
)
assert stop_fn(result['best_reward'])
assert stop_fn(result["best_reward"])

if __name__ == '__main__':
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
@@ -214,5 +216,5 @@ def test_pc51(args=get_args()):
test_c51(args)


if __name__ == '__main__':
if __name__ == "__main__":
test_c51(get_args())
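The c51 and rainbow tests also persist the replay buffer with pickle. A tiny round-trip of just that piece in isolation; the sizes and file name are illustrative, and with-blocks are used so the file handles are closed explicitly:

import pickle

from tianshou.data import VectorReplayBuffer

buf = VectorReplayBuffer(1000, buffer_num=4)  # illustrative sizes
with open("train_buffer.pkl", "wb") as f:
    pickle.dump(buf, f)
with open("train_buffer.pkl", "rb") as f:
    restored_buf = pickle.load(f)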
38 changes: 20 additions & 18 deletions test/discrete/test_rainbow.py
@@ -98,7 +98,7 @@ def noisy_linear(x, y):
"linear_layer": noisy_linear
}, {
"linear_layer": noisy_linear
})
}),
)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = RainbowPolicy(
@@ -109,7 +109,7 @@ def noisy_linear(x, y):
args.v_min,
args.v_max,
args.n_step,
target_update_freq=args.target_update_freq
target_update_freq=args.target_update_freq,
).to(args.device)
# buffer
if args.prioritized_replay:
Expand All @@ -118,7 +118,7 @@ def noisy_linear(x, y):
buffer_num=len(train_envs),
alpha=args.alpha,
beta=args.beta,
weight_norm=True
weight_norm=True,
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -128,12 +128,12 @@ def noisy_linear(x, y):
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, 'rainbow')
log_path = os.path.join(args.logdir, args.task, "rainbow")
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer, save_interval=args.save_interval)

def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold
@@ -164,29 +164,31 @@ def test_fn(epoch, env_step):

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, "checkpoint.pth")
# Example: saving by epoch num
# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
torch.save(
{
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth')
)
pickle.dump(
train_collector.buffer,
open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
"model": policy.state_dict(),
"optim": optim.state_dict(),
}, ckpt_path
)
buffer_path = os.path.join(log_path, "train_buffer.pkl")
pickle.dump(train_collector.buffer, open(buffer_path, "wb"))
return ckpt_path

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
ckpt_path = os.path.join(log_path, "checkpoint.pth")
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
buffer_path = os.path.join(log_path, "train_buffer.pkl")
if os.path.exists(buffer_path):
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
print("Successfully restore buffer.")
@@ -210,11 +212,11 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
save_best_fn=save_best_fn,
logger=logger,
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn
save_checkpoint_fn=save_checkpoint_fn,
)
assert stop_fn(result['best_reward'])
assert stop_fn(result["best_reward"])

if __name__ == '__main__':
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
@@ -238,5 +240,5 @@ def test_prainbow(args=get_args()):
test_rainbow(args)


if __name__ == '__main__':
if __name__ == "__main__":
test_rainbow(get_args())
30 changes: 17 additions & 13 deletions test/offline/test_discrete_bcq.py
@@ -25,7 +25,7 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument('--reward-threshold', type=float, default=None)
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=3e-4)
@@ -37,7 +37,7 @@ def get_args():
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=2000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
@@ -104,33 +104,37 @@ def test_discrete_bcq(args=get_args()):
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)

log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
log_path = os.path.join(args.logdir, args.task, "discrete_bcq")
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer, save_interval=args.save_interval)

def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, "checkpoint.pth")
# Example: saving by epoch num
# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
torch.save(
{
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth')
"model": policy.state_dict(),
"optim": optim.state_dict(),
}, ckpt_path
)
return ckpt_path

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
ckpt_path = os.path.join(log_path, "checkpoint.pth")
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
policy.load_state_dict(checkpoint["model"])
optim.load_state_dict(checkpoint["optim"])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
@@ -147,11 +151,11 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
save_best_fn=save_best_fn,
logger=logger,
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn
save_checkpoint_fn=save_checkpoint_fn,
)
assert stop_fn(result['best_reward'])
assert stop_fn(result["best_reward"])

if __name__ == '__main__':
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)