Fix handling of torch "device" association #810

@jamartinh

Description

  • I have marked all applicable categories:
    • exception-raising bug
    • RL algorithm bug
    • documentation request (i.e. "X is missing from the documentation.")
    • new feature request
  • I have visited the source website
  • I have searched through the issue tracker for duplicates
  • I have mentioned version numbers, operating system and environment, where applicable:
    import tianshou, gym, torch, numpy, sys
    print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)

0.4.11 0.21.0 1.12.1.post200 1.23.5 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:23:14) [GCC 10.4.0] linux

I think a simple, clear and robust mechanism should be designed to move policies between devices.

The only way I have found so far to move a SAC policy saved on CUDA to the CPU is the following:

   import torch
   from tianshou.policy import SACPolicy

   device = "cpu"  # target device
   policy: SACPolicy = torch.load(model_file_name, map_location="cpu")
   policy.to(device)
   policy.device = device

   policy.actor.device = device
   policy.actor.preprocess.device = device
   policy.actor.preprocess.model.device = device
   policy.actor.mu.device = device
   policy.actor.sigma.device = device


   policy.critic1.device = device
   policy.critic1.preprocess.device = device
   policy.critic1.preprocess.model.device = device
   policy.critic1.last.device = device

   policy.critic1_old.device = device
   policy.critic1_old.preprocess.device = device
   policy.critic1_old.preprocess.model.device = device
   policy.critic1_old.last.device = device

   policy.critic2.device = device
   policy.critic2.preprocess.device = device
   policy.critic2.preprocess.model.device = device
   policy.critic2.last.device = device

   policy.critic2_old.device = device
   policy.critic2_old.preprocess.device = device
   policy.critic2_old.preprocess.model.device = device
   policy.critic2_old.last.device = device


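   # Reloading each optimizer's own state dict casts its state tensors
   # to the device of the (already moved) parameters.
   policy.actor_optim.load_state_dict(policy.actor_optim.state_dict())
   policy.critic1_optim.load_state_dict(policy.critic1_optim.state_dict())
   policy.critic2_optim.load_state_dict(policy.critic2_optim.state_dict())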
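
The per-attribute assignments above could probably be collapsed into a loop over the module tree. The following is only a sketch, not a tianshou API: `set_device_recursively` is a hypothetical helper name, and it assumes that every sub-network keeps its device in a plain instance attribute called `device`, as the snippet above suggests. `TinyNet` and `TinyActor` are toy stand-ins for the nested networks of a policy.

```python
import torch
from torch import nn


def set_device_recursively(module: nn.Module, device) -> nn.Module:
    """Move `module` to `device` and overwrite every stored `device` attribute.

    Hypothetical helper (not part of tianshou). Assumes sub-modules keep
    their device in a plain instance attribute named `device`.
    """
    device = torch.device(device)
    module.to(device)  # moves parameters and buffers
    for sub in module.modules():  # includes `module` itself
        if "device" in vars(sub):  # only touch plain instance attributes
            sub.device = device
    return module


# Toy stand-ins for illustration only.
class TinyNet(nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.device = device
        self.model = nn.Linear(4, 8)


class TinyActor(nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.device = device
        self.preprocess = TinyNet(device)
        self.mu = nn.Linear(8, 2)


# The "cuda:0" tag is just a stale string here; no GPU is required.
actor = TinyActor(device="cuda:0")
set_device_recursively(actor, "cpu")
```

With a helper like this, the long list of assignments would reduce to one call on the policy, followed by the optimizer state reload shown above. Whether it covers every object that stores a device in a given tianshou version is something that would need checking.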

One idea is to make the device variable in every (sub)model a reference to a single shared variable that resides in just one place.
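
A minimal sketch of the shared-variable idea, independent of torch: plain classes stand in for the networks, and `DeviceRef`, `SubModel`, `Model` and `set_device` are hypothetical names chosen for illustration. Every sub-model holds a reference to the same holder object, so one assignment changes what all of them report.

```python
class DeviceRef:
    """Single mutable holder for the device name, shared by all sub-models."""

    def __init__(self, name: str = "cpu"):
        self.name = name


class SubModel:
    def __init__(self, device_ref: DeviceRef):
        self._device_ref = device_ref  # shared object, not a copy

    @property
    def device(self) -> str:
        return self._device_ref.name


class Model:
    def __init__(self, device_ref: DeviceRef):
        self._device_ref = device_ref
        self.actor = SubModel(device_ref)
        self.critic = SubModel(device_ref)

    @property
    def device(self) -> str:
        return self._device_ref.name

    def set_device(self, name: str) -> None:
        # One assignment updates what every sub-model reports.
        self._device_ref.name = name


model = Model(DeviceRef("cuda:0"))
model.set_device("cpu")
print(model.actor.device, model.critic.device)  # cpu cpu
```

In a real policy the tensors would still have to be moved with `.to(device)`; the holder only keeps the bookkeeping attribute in one place.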

Another idea is to make .device a function (or property) that reads the value from some variable when called, instead of storing a copy in each object.
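
One way the function-based idea might look is a read-only property that derives the device from the module's own parameters. This is a sketch under assumptions, not how tianshou implements it: `DeviceFromParams` and `Critic` are hypothetical classes, and the approach assumes all parameters of a module live on the same device.

```python
from itertools import chain

import torch
from torch import nn


class DeviceFromParams(nn.Module):
    """Hypothetical base class: `device` is computed on access, never stored."""

    @property
    def device(self) -> torch.device:
        # Use the first parameter or buffer; fall back to CPU if there is none.
        first = next(chain(self.parameters(), self.buffers()), None)
        return first.device if first is not None else torch.device("cpu")


class Critic(DeviceFromParams):
    def __init__(self):
        super().__init__()
        self.last = nn.Linear(8, 1)

    def forward(self, obs):
        # Inputs follow the parameters, wherever they currently are.
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        return self.last(obs)


critic = Critic()
out = critic([[0.0] * 8])
```

Because nothing is stored, calling `critic.to(...)` would be enough to keep `critic.device` consistent, and a checkpoint loaded with `map_location` would report the new device without any manual patching.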

Metadata

Labels: bug, good first issue

Status: Done