Closed
Labels: bug, good first issue
Description
- I have marked all applicable categories:
- exception-raising bug
- RL algorithm bug
- documentation request (i.e. "X is missing from the documentation.")
- new feature request
- I have visited the source website
- I have searched through the issue tracker for duplicates
- I have mentioned version numbers, operating system and environment, where applicable:
import tianshou, gym, torch, numpy, sys
print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
0.4.11 0.21.0 1.12.1.post200 1.23.5 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:23:14) [GCC 10.4.0] linux
I think a simple, clear and robust mechanism should be designed to move policies between devices.
The only way I have found so far to move a SAC policy that was saved on CUDA to the CPU is the following:
import torch
from tianshou.policy import SACPolicy

device = "cpu"  # target device
policy: SACPolicy = torch.load(model_file_name, map_location=device)
policy.to(device)
policy.device = device
policy.actor.device = device
policy.actor.preprocess.device = device
policy.actor.preprocess.model.device = device
policy.actor.mu.device = device
policy.actor.sigma.device = device
policy.critic1.device = device
policy.critic1.preprocess.device = device
policy.critic1.preprocess.model.device = device
policy.critic1.last.device = device
policy.critic1_old.device = device
policy.critic1_old.preprocess.device = device
policy.critic1_old.preprocess.model.device = device
policy.critic1_old.last.device = device
policy.critic2.device = device
policy.critic2.preprocess.device = device
policy.critic2.preprocess.model.device = device
policy.critic2.last.device = device
policy.critic2_old.device = device
policy.critic2_old.preprocess.device = device
policy.critic2_old.preprocess.model.device = device
policy.critic2_old.last.device = device
# Reloading each optimizer's own state dict casts its state tensors to the device of its parameters.
policy.actor_optim.load_state_dict(policy.actor_optim.state_dict())
policy.critic1_optim.load_state_dict(policy.critic1_optim.state_dict())
policy.critic2_optim.load_state_dict(policy.critic2_optim.state_dict())
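The per-attribute assignments in the listing above could in principle be replaced by a walk over the module tree. The following is only a sketch, under the assumption that every sub-network caches its device in a plain, writable attribute named `device` (as the listing suggests). `move_policy`, `ToyNet` and `ToyPolicy` are hypothetical names used for illustration and are not part of Tianshou.

```python
from typing import Union

import torch
from torch import nn


def move_policy(policy: nn.Module, device: Union[str, torch.device]) -> nn.Module:
    """Hypothetical helper: move a policy and refresh every cached `device` attribute."""
    device = torch.device(device)
    policy.to(device)  # moves parameters and buffers of all registered submodules
    for module in policy.modules():  # the policy itself plus every nested submodule
        # Assumption: a cached device is a plain, writable attribute named `device`.
        if hasattr(module, "device"):
            module.device = device
    for value in vars(policy).values():
        if isinstance(value, torch.optim.Optimizer):
            # Same trick as in the listing: reloading casts optimizer state
            # to the device of the parameters.
            value.load_state_dict(value.state_dict())
    return policy


# Toy stand-ins for a policy and its sub-network (not Tianshou classes).
class ToyNet(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device  # stale cached device, as in the issue
        self.model = nn.Linear(4, 2)


class ToyPolicy(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device
        self.actor = ToyNet(device)
        self.actor_optim = torch.optim.Adam(self.actor.parameters())


toy = move_policy(ToyPolicy("cuda"), "cpu")
```

With a real policy the call would look like `move_policy(policy, "cpu")` right after `torch.load`. The helper only finds optimizers stored directly as attributes of the policy object, which matches the three optimizers in the listing but may not hold for other policies.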
One idea is to have every device attribute in every (sub)model refer to a single shared variable that lives in just one place.
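A minimal sketch of the shared-variable idea, assuming the networks are allowed to hold a small mutable holder object instead of their own copy of the device. `DeviceRef` and `SharedDeviceNet` are hypothetical names for illustration only; nothing like them exists in Tianshou as far as this issue shows.

```python
import torch
from torch import nn


class DeviceRef:
    """Hypothetical holder: the one place where the current device is stored."""

    def __init__(self, device):
        self.value = torch.device(device)


class SharedDeviceNet(nn.Module):
    """Toy network that reads the device through the shared holder on every call."""

    def __init__(self, device_ref: DeviceRef):
        super().__init__()
        self.device_ref = device_ref  # a reference, not a copy
        self.model = nn.Linear(4, 2)

    def forward(self, obs):
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device_ref.value)
        return self.model(obs)


ref = DeviceRef("cpu")
actor, critic = SharedDeviceNet(ref), SharedDeviceNet(ref)
out = actor([[0.0, 1.0, 2.0, 3.0]])
```

Changing `ref.value` once is then visible to every network that holds `ref`, so moving a policy would reduce to `policy.to(device)` plus a single assignment.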
Another idea is to make .device a function (or property) that looks the value up from one such variable when it is called, instead of storing a separate copy in each module.
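One way to read the second idea is to compute `.device` on access from the module's own parameters, so that nothing has to be kept in sync after `.to(...)`. This is a sketch under that assumption, not a description of how Tianshou works. `DeviceFromParams` and `TinyActor` are hypothetical names, and falling back to the CPU for parameter-free modules is an arbitrary choice made for the example.

```python
import torch
from torch import nn


class DeviceFromParams(nn.Module):
    """Hypothetical base class: report the device of the first parameter instead of caching one."""

    @property
    def device(self) -> torch.device:
        try:
            return next(self.parameters()).device
        except StopIteration:
            # Assumption: a module without parameters is treated as living on the CPU.
            return torch.device("cpu")


class TinyActor(DeviceFromParams):
    """Toy network that converts its input using the computed device."""

    def __init__(self):
        super().__init__()
        self.mu = nn.Linear(4, 2)

    def forward(self, obs):
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        return self.mu(obs)


actor = TinyActor()
out = actor([[0.0, 0.0, 0.0, 0.0]])
```

Because `device` is read-only here, code that assigns to it (such as the manual listing earlier in this issue) would raise an `AttributeError`, so adopting this design would mean dropping those assignments.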