
found invalid values to satisfy the constraint Simplex() for distribution Categorical #823

@qldeng

>>> import tianshou, gym, torch, numpy, sys
>>> 
>>> print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
0.4.11 0.23.1 1.12.0+cu113 1.22.3 3.8.13 (default, Mar 28 2022, 11:38:47) 
[GCC 7.5.0] linux
>>> 

When running the test file "test/discrete/test_ppo.py", I got this error:

Traceback (most recent call last):
  File "/me4data/dql/quant/alpha/model/agent/test/test_ppo.py", line 160, in <module>
    test_ppo()
  File "/me4data/dql/quant/alpha/model/agent/test/test_ppo.py", line 132, in test_ppo
    result = onpolicy_trainer(
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/onpolicy.py", line 150, in onpolicy_trainer
    return OnpolicyTrainer(*args, **kwargs).run()
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/base.py", line 441, in run
    deque(self, maxlen=0)  # feed the entire iterator into a zero-length deque
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/base.py", line 252, in __iter__
    self.reset()
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/base.py", line 237, in reset
    test_result = test_episode(
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/trainer/utils.py", line 27, in test_episode
    result = collector.collect(n_episode=n_episode)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/data/collector.py", line 297, in collect
    result = self.policy(self.data, last_state)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/tianshou/policy/modelfree/pg.py", line 112, in forward
    dist = self.dist_fn(logits)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/torch/distributions/categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "/me4data/dql/miniconda3/envs/pt/lib/python3.8/site-packages/torch/distributions/distribution.py", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter probs (Tensor of shape (100, 2)) of distribution Categorical(probs: torch.Size([100, 2])) to satisfy the constraint Simplex(), but found invalid values:

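Should `dist_fn` be called with a keyword argument to fix this bug? When `logits` is passed positionally, it binds to the first parameter of `torch.distributions.Categorical`, which is `probs`, not `logits`: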

dist = self.dist_fn(logits=logits)
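A minimal sketch of the difference, assuming `dist_fn` is `torch.distributions.Categorical` (the traceback's `categorical.py` frame suggests so). The tensor `raw` is illustrative, not taken from the issue. A positional argument is validated against `Simplex()` as `probs`. A keyword `logits=` argument is instead normalized with a log-softmax.

```python
import torch
from torch.distributions import Categorical

# Illustrative unnormalized scores (hypothetical, not from the issue):
# 2 states x 2 actions, including a negative entry.
raw = torch.tensor([[1.0, -2.0], [0.5, 0.3]])

# 1) Positional call: `raw` binds to `probs`. Categorical divides it by its
#    row sums and checks the Simplex() constraint, which fails here because
#    the first row still contains a negative value after normalization.
try:
    Categorical(raw, validate_args=True)
    positional_ok = True
except ValueError:
    positional_ok = False

# 2) Keyword call: `raw` is treated as logits, so any real values are accepted
#    and converted to probabilities via softmax.
dist = Categorical(logits=raw, validate_args=True)

# 3) If the network already outputs probabilities, positional use is valid.
probs = torch.softmax(raw, dim=-1)
dist_from_probs = Categorical(probs, validate_args=True)

print(positional_ok)  # False
print(dist.probs)     # same values as torch.softmax(raw, dim=-1)
```

Which call is correct depends on what the actor network emits. If it already applies a softmax, passing its output as `logits=` would normalize it a second time and yield a different distribution. Also, if the output contains NaN, both calls fail validation, so switching to `logits=` would not help.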
