-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Open
Description
I'd like to train graph neural networks (GNNs), which typically require examples in a structured format. Examples also need to be batched with the block-diagonal matrix trick (stacking each graph's adjacency matrix into one large block-diagonal matrix), but I might be able to handle that myself in my custom `Net` subclass.
I can't find a way for an environment to pass structured data to the model, either as observations or in the `info` dictionary. The model doesn't receive the full observation dict (only an array extracted from it), and the `info` dict isn't always passed through to it.
Here's a minimal example demonstrating my issue:
from venv import logger
import tianshou as ts
import gymnasium as gym
import torch
class TestNet(ts.utils.net.common.Net):
    """Tianshou ``Net`` that logs every argument passed to ``forward``.

    Used to inspect exactly what a policy hands to its model during
    collection and during training.
    """

    def __init__(self, s_shape, a_shape, h_sizes):
        """Build the underlying network.

        :param s_shape: observation shape, forwarded as ``state_shape``.
        :param a_shape: action shape, forwarded as ``action_shape``.
        :param h_sizes: hidden layer sizes, forwarded as ``hidden_sizes``.
        """
        super().__init__(state_shape=s_shape, action_shape=a_shape, hidden_sizes=h_sizes)

    def forward(self, obs, state=None, info=None):
        """Log the received inputs, then delegate to the parent ``forward``.

        ``state`` and ``info`` default to ``None`` so that callers passing
        only ``obs`` do not hit a ``TypeError``. NOTE(review): this is
        presumably also what the parent ``Net.forward`` signature allows;
        confirm against the installed tianshou version.
        """
        logger.warning(f'Model received: obs={obs} state={state} info={info}')
        return super().forward(obs, state, info)
class TestEnv(gym.Env):
    """Single-step environment with a ``Dict`` observation space.

    Observations are random samples of the space. Every episode terminates
    after one step with a reward of 0.
    """

    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Dict({
            # It appears the value of the 'obs' key is what is passed to the model.
            'obs': gym.spaces.Box(low=0, high=1, shape=(2,)),
            'extra_data': gym.spaces.Box(low=0, high=1, shape=(2,)),
        })

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        ``super().reset(seed=...)`` seeds ``self.np_random``, but the
        observation below is drawn from the space's own RNG. That RNG is
        seeded explicitly so a seeded reset yields a reproducible observation.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.observation_space.seed(seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        """Ignore ``action`` and end the episode immediately.

        Returns ``(observation, reward, terminated, truncated, info)``.
        """
        return self.observation_space.sample(), 0, True, False, {}
def test():
    """Show what a DQN policy passes to its model in two code paths.

    First runs a single collection step directly, then one epoch of
    off-policy training. The ``TestNet.forward`` log lines from each path
    can then be compared.
    """
    env = TestEnv()
    net = TestNet(s_shape=(2, ), a_shape=2, h_sizes=[16, 16])
    optim = torch.optim.Adam(net.parameters(), lr=1e-2)
    policy = ts.policy.DQNPolicy(model=net, optim=optim, action_space=env.action_space)
    collector = ts.data.Collector(policy, env)
    collector.reset()
    collector.collect(n_step=1)
    # Prints:
    # Model received: obs=[[0.15468411 0.8790275 ]] state=None info=Batch()
    # The test collector gets its own environment instance. Sharing `env`
    # with the training collector would let test-time resets and steps
    # interfere with the training collector's episode state.
    test_collector = ts.data.Collector(policy, TestEnv())
    test_collector.reset()
    ts.trainer.OffpolicyTrainer(
        policy=policy,
        train_collector=collector,
        test_collector=test_collector,
        max_epoch=1,
        step_per_epoch=1,
        step_per_collect=1,
        batch_size=1,
        episode_per_test=1,
    ).run()
    # Prints:
    # Model received: obs=[[0.5437955 0.99715394]] state=None info=[None]
# Run the repro only when executed as a script, not when the module is
# imported (e.g. by a test runner collecting the `test` function).
if __name__ == "__main__":
    test()
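The first print shows that my model can't access the structured observation (it receives only a single array, not the full dict). The second shows why I can't put my structured data in the `info` dict instead: in the trainer run, the model receives `info=[None]`.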
Metadata
Metadata
Assignees
Labels
No labels