
"test_reward" KeyError in epoch_stat["test_reward"] when saving checkpint through trainer information #913

@127161782

Description

  • I have marked all applicable categories:
    • exception-raising bug
    • RL algorithm bug
    • documentation request (i.e. "X is missing from the documentation.")
    • new feature request
  • I have visited the source website
  • I have searched through the issue tracker for duplicates
  • I have mentioned version numbers, operating system and environment, where applicable:
    ```python
    import tianshou, gymnasium as gym, torch, numpy, sys
    print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
    ```

```
0.4.11 0.28.1 1.12.1 1.23.5 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:18)
[GCC 10.3.0] linux
```

Below is part of my code:
```python
  def save_check_history_best_point__fn(epoch, epoch_stat, info):
      # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
      rwd = epoch_stat["test_reward"]
      tmp_path = os.path.join(log_path, f"chkpt_{epoch}_{rwd}")
      if not os.path.exists(tmp_path):
          os.makedirs(tmp_path)
      ckpt_path = os.path.join(tmp_path, "checkpoint.pth")
      # Example: saving by epoch num
      # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
      torch.save(
          {
              "model": policy.state_dict(),
              "optim": optim.state_dict(),
          }, ckpt_path
      )
      return ckpt_path
 # trainer
  trainer = OnpolicyTrainer(
      policy,
      train_collector,
      test_collector,
      args.epoch,
      args.step_per_epoch,
      args.repeat_per_collect,
      args.test_num,
      args.batch_size,
      episode_per_collect=args.episode_per_collect,
      stop_fn=stop_fn,
      save_best_fn=save_best_fn,
      logger=tensorboard_logger,
      resume_from_log=args.resume,
      # save_checkpoint_fn=save_checkpoint_fn,
  )

  for epoch, epoch_stat, info in trainer:
      logger.info(f"Epoch:{epoch},epoch_stat:{epoch_stat},info:{info}")
      if epoch_stat["test_reward"] >= epoch_stat["best_reward"]:
          save_check_history_best_point__fn(epoch, epoch_stat, info)
```

The error is as follows:

```
Epoch #4: 163810it [01:29, 1840.23it/s, env_step=647572, len=959, n/ep=16, n/st=15359, rew=5747.34]
Traceback (most recent call last):
  File "ppo.py", line 270, in <module>
    test_ppo()
  File "ppo.py", line 244, in test_ppo
    if epoch_stat["test_reward"] >= epoch_stat["best_reward"]:
KeyError: 'test_reward'
```

My logs are as follows:

```
2023-08-15 14:23:16,832 [test_ppo: ppo.py,243] INFO: Epoch:1,epoch_stat:{'test_reward': 5037.995699169552, 'test_reward_std': 1727.1946382581957, 'best_reward': 5142.779300127992, 'best_reward_std': 1656.0104516781603, 'best_epoch': 0, 'loss': 0.030237332731485368, 'loss/clip': 0.018954908354207874, 'loss/vf': 0.045129697881639, 'loss/ent': -11.023293628692628, 'gradient_step': 2462, 'env_step': 158340, 'rew': 5676.072051357914, 'len': 983, 'n/ep': 16, 'n/st': 15743},info:{'duration': '98.45s', 'train_time/model': '24.10s', 'test_step': 455647, 'test_episode': 500, 'test_time': '40.49s', 'test_speed': '11252.92 step/s', 'best_reward': 5142.779300127992, 'best_result': '5142.78 ± 1656.01', 'train_step': 158340, 'train_episode': 176, 'train_time/collector': '33.86s', 'train_speed': '2731.75 step/s'}
2023-08-15 14:25:20,525 [test_ppo: ppo.py,243] INFO: Epoch:2,epoch_stat:{'test_reward': 4965.068494208408, 'test_reward_std': 1797.844330770504, 'best_reward': 5142.779300127992, 'best_reward_std': 1656.0104516781603, 'best_epoch': 0, 'loss': 0.04299093971960247, 'loss/clip': 0.02695085948333144, 'loss/vf': 0.06416032150387764, 'loss/ent': -10.992865056991578, 'gradient_step': 5038, 'env_step': 323648, 'rew': 5876.510187073267, 'len': 979, 'n/ep': 16, 'n/st': 15666},info:{'duration': '222.15s', 'train_time/model': '49.13s', 'test_step': 1166256, 'test_episode': 1300, 'test_time': '104.56s', 'test_speed': '11153.57 step/s', 'best_reward': 5142.779300127992, 'best_result': '5142.78 ± 1656.01', 'train_step': 323648, 'train_episode': 352, 'train_time/collector': '68.46s', 'train_speed': '2752.48 step/s'}
2023-08-15 14:26:26,616 [test_ppo: ppo.py,243] INFO: Epoch:3,epoch_stat:{'test_reward': 5325.084627629465, 'test_reward_std': 1489.5582454379198, 'best_reward': 5325.084627629465, 'best_reward_std': 1489.5582454379198, 'best_epoch': 3, 'loss': 0.043803104497492314, 'loss/clip': 0.023312768880277873, 'loss/vf': 0.08196134265512228, 'loss/ent': -11.039208106994629, 'gradient_step': 7530, 'env_step': 483762, 'rew': 5179.871382976435, 'len': 942, 'n/ep': 16, 'n/st': 15078},info:{'duration': '288.24s', 'train_time/model': '73.12s', 'test_step': 1257045, 'test_episode': 1400, 'test_time': '112.72s', 'test_speed': '11151.51 step/s', 'best_reward': 5325.084627629465, 'best_result': '5325.08 ± 1489.56', 'train_step': 483762, 'train_episode': 528, 'train_time/collector': '102.39s', 'train_speed': '2756.26 step/s'}
2023-08-15 14:27:55,634 [test_ppo: ppo.py,243] INFO: Epoch:4,epoch_stat:{'loss': 0.035433512907475234, 'loss/clip': 0.022555462159216402, 'loss/vf': 0.05151220228523016, 'loss/ent': -11.06116304397583, 'gradient_step': 9840, 'env_step': 647572, 'rew': 5747.335802369741, 'len': 959, 'n/ep': 16, 'n/st': 15359},info:{'duration': '377.26s', 'train_time/model': '95.45s', 'test_step': 1624828, 'test_episode': 1800, 'test_time': '145.51s', 'test_speed': '11166.17 step/s', 'best_reward': 5602.059048982654, 'best_result': '5602.06 ± 1151.13', 'train_step': 647572, 'train_episode': 704, 'train_time/collector': '136.29s', 'train_speed': '2794.36 step/s'}
```

I suspect that once the reward_threshold is reached, self.stop_fn_flag becomes True, so the trainer skips the subsequent

 epoch_stat.update(test_stat)

and epoch_stat therefore never receives test_reward. Note that in the Epoch 4 log above, epoch_stat no longer contains test_reward or best_reward, unlike Epochs 1-3. The relevant code is here:
https://github.com/thu-ml/tianshou/blob/80a698be529639bebfa398a280f7140a6bc16998/tianshou/trainer/base.py#L288C21-L288C72
https://github.com/thu-ml/tianshou/blob/master/tianshou/trainer/base.py#L309C1-L317C49
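Until this is fixed in the trainer itself, a possible user-side workaround is to make the loop tolerant of the missing keys. This is only a sketch of the loop above; it assumes epoch_stat may simply lack the test keys on the epoch where stop_fn triggers, and falls back to info, which (as the logs show) still carries best_reward:

```python
for epoch, epoch_stat, info in trainer:
    logger.info(f"Epoch:{epoch},epoch_stat:{epoch_stat},info:{info}")
    # On the epoch where stop_fn triggers, epoch_stat may not contain
    # "test_reward"/"best_reward"; fall back to info, which still has best_reward.
    test_reward = epoch_stat.get("test_reward")
    best_reward = epoch_stat.get("best_reward", info.get("best_reward"))
    if test_reward is not None and best_reward is not None and test_reward >= best_reward:
        save_check_history_best_point__fn(epoch, epoch_stat, info)
```

This avoids the KeyError; the comparison is simply skipped on an epoch where the test statistics are absent from epoch_stat.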

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), not reproduced yet (Not yet tested or reproduced by a reviewer)
Status: Done
