
Is it possible for DQN to use MultiDiscrete action space design? #643

@127161782

Description

import gym
import numpy as np

class EnvExample(gym.Env):
    def __init__(self, max=30, render=0):
        self.max = max
        self.end = False
        self.action_space = gym.spaces.MultiDiscrete([2] * 9)  # 9 binary sub-actions
        self.observation_space = gym.spaces.Box(low=0, high=3, shape=(9,), dtype=np.float32)
        self.state = np.zeros(9, dtype=np.float32)

    def reset(self):
        # Must return an observation matching observation_space, not an empty list.
        self.state = np.zeros(9, dtype=np.float32)
        return self.state

    def step(self, action):
        '''action = [1 0 0 1 0 1 1 0 1]'''
        state = self.observation_space.sample()
        reward = 1
        if self.max <= 0:
            self.end = True
        self.max -= 1
        return state, reward, self.end, {}

    def render(self, mode="human"):
        print("___")

I used a Discrete action space before, but the action space has since grown too large for a single Discrete to cover, so I want to pass a 0/1 encoding as the action, as in the code above. However, collecting data fails with this error:

   train_collector.collect(n_step=args.batch_size * args.training_num)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/tianshou/data/collector.py", line 234, in collect
    result = self.policy(self.data, last_state)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/tianshou/policy/modelfree/dqn.py", line 163, in forward
    logits, hidden = model(obs_next, state=state, info=batch.info)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/tianshou/utils/net/common.py", line 195, in forward
    logits = self.model(obs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/tianshou/utils/net/common.py", line 97, in forward
    return self.model(obs.flatten(1))  # type: ignore
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (20x0 and 9x128)
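Two things seem to be going on here. The `(20x0 and 9x128)` shape mismatch suggests the observation batch has zero features, which happens when `reset()` returns an empty list instead of a 9-element array. Separately, Tianshou's DQN expects a Discrete action space (one Q-value per action), so a `MultiDiscrete([2] * 9)` space is usually handled by flattening it into `Discrete(2**9)`. A minimal sketch of that encoding, assuming plain NumPy (in practice you would wrap this in a `gym.ActionWrapper` and decode inside `step`):

```python
import numpy as np

# Sub-action sizes of MultiDiscrete([2] * 9): nine binary choices.
NVEC = (2,) * 9

def encode(vec):
    # Pack a MultiDiscrete action vector into a single Discrete index,
    # treating the vector as digits in a mixed-radix number.
    return int(np.ravel_multi_index(tuple(vec), NVEC))

def decode(idx):
    # Unpack a Discrete index back into the MultiDiscrete action vector.
    return np.array(np.unravel_index(idx, NVEC))

# The flattened action space for DQN would then be Discrete(512).
n_actions = int(np.prod(NVEC))  # 2**9 = 512

# The example action from the docstring round-trips through the encoding.
idx = encode([1, 0, 0, 1, 0, 1, 1, 0, 1])
vec = decode(idx)
```

This keeps the DQN side unchanged: the policy sees `Discrete(512)` and its Q-network outputs 512 values, while the environment still receives the 9-bit vector after decoding. Note that this only scales to a few dozen binary sub-actions at most, since the number of Q-outputs grows exponentially.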
