import gym
import numpy as np

class EnvExample(gym.Env):
    def __init__(self, max=30, render=0):
        self.state = []
        self.max = max
        # one binary choice per position: a 0/1 vector of length 9
        self.action_space = gym.spaces.MultiDiscrete([2] * 9)
        self.end = False
        # self.observation_space = gym.spaces.MultiDiscrete(obs)
        self.observation_space = gym.spaces.Box(low=0, high=3, shape=(9,), dtype=np.float32)

    def reset(self):
        return self.state

    def step(self, action):
        '''action = [1 0 0 1 0 1 1 0 1]'''
        state = self.observation_space.sample()
        reward = 1
        if self.max <= 0:
            self.end = True
        self.max -= 1
        return state, reward, self.end, {}

    def render(self, mode="human"):
        print("".join("___"))
I previously used a Discrete action space, but the action space has now grown too large for Discrete to be practical, so I want to pass a 0/1 encoding (a MultiDiscrete action) as shown above. However, collecting data fails with this error:
train_collector.collect(n_step=args.batch_size * args.training_num)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/tianshou/data/collector.py", line 234, in collect
    result = self.policy(self.data, last_state)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/tianshou/policy/modelfree/dqn.py", line 163, in forward
    logits, hidden = model(obs_next, state=state, info=batch.info)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/tianshou/utils/net/common.py", line 195, in forward
    logits = self.model(obs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/tianshou/utils/net/common.py", line 97, in forward
    return self.model(obs.flatten(1))  # type: ignore
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zl/.conda/envs/mytorch/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (20x0 and 9x128)
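The shapes in the error suggest the network received observations with zero features: mat1 is 20x0 (a batch of 20 observations, each flattening to width 0), while the first Linear layer expects 9 inputs (9x128). In the environment above, reset() returns self.state, which is the empty list []. A minimal sketch of a reset that returns an observation matching the declared Box(shape=(9,)) space (assuming the intended initial state is simply zeros; the reset of self.max here is hypothetical):

    import numpy as np

    def reset(self):
        # return a 9-dimensional observation consistent with observation_space,
        # instead of the empty list [] that produces a (batch, 0) input
        self.state = np.zeros(9, dtype=np.float32)
        self.end = False
        self.max = 30  # hypothetical: restore the step budget on reset
        return self.state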