Closed
Labels
question: Further information is requested
Description
>>> import tianshou, gym, torch, numpy, sys
>>>
>>> print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
0.4.11 0.23.1 1.12.0+cu113 1.22.3 3.8.13 (default, Mar 28 2022, 11:38:47)
[GCC 7.5.0] linux
>>>
When running the test file "test/discrete/test_ppo.py", I got this error:
Traceback (most recent call last):
  File "/me4data/dql/quant/alpha/model/agent/test/test_ppo.py", line 160, in <module>
    test_ppo()
  File "/me4data/dql/quant/alpha/model/agent/test/test_ppo.py", line 132, in test_ppo
    result = onpolicy_trainer(
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/onpolicy.py", line 150, in onpolicy_trainer
    return OnpolicyTrainer(*args, **kwargs).run()
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/base.py", line 441, in run
    deque(self, maxlen=0)  # feed the entire iterator into a zero-length deque
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/base.py", line 252, in __iter__
    self.reset()
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/base.py", line 237, in reset
    test_result = test_episode(
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/utils.py", line 27, in test_episode
    result = collector.collect(n_episode=n_episode)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/data/collector.py", line 297, in collect
    result = self.policy(self.data, last_state)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/policy/modelfree/pg.py", line 112, in forward
    dist = self.dist_fn(logits)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/torch/distributions/categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/torch/distributions/distribution.py", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter probs (Tensor of shape (100, 2)) of distribution Categorical(probs: torch.Size([100, 2])) to satisfy the constraint Simplex(), but found invalid values:
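The first positional parameter of Categorical is probs, not logits, so the positional call treats the network output as probabilities. Should "dist_fn" be called with a keyword argument to fix this bug?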
dist = self.dist_fn(logits=logits)
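For context, `torch.distributions.Categorical` accepts either `probs` or `logits`, and a value passed as `probs` has to lie on the probability simplex: every entry non-negative and each row summing to 1. The sketch below is a plain-Python illustration of that constraint, not PyTorch's actual validation code. The helper names `is_on_simplex` and `softmax` and the sample values are hypothetical. It shows why raw network outputs would fail the check while softmax-normalized outputs pass. NaN entries, for example from a diverged network, would fail it as well.

```python
import math


def is_on_simplex(row, tol=1e-6):
    """Return True if every entry is non-negative and the entries sum to 1.

    Hypothetical helper that mirrors the idea of the Simplex() constraint;
    it is not the check PyTorch runs internally.
    """
    return all(x >= 0 for x in row) and abs(sum(row) - 1.0) <= tol


def softmax(row):
    """Map arbitrary real-valued logits to probabilities that sum to 1."""
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw network output for one state with 2 actions.
raw_logits = [1.5, -0.3]
probs = softmax(raw_logits)

# Raw logits contain a negative entry and do not sum to 1.
print(is_on_simplex(raw_logits))
# Softmax output is a valid probability vector.
print(is_on_simplex(probs))
```

If the actor network already applies a softmax, the positional call receives valid probabilities and the keyword form with `logits=` would instead reinterpret them as logits, so which form is right depends on what the network outputs.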