
Using graph models with structured env observations #1249

@mjmartis

I'd like to train graph neural networks (GNNs), which typically require examples in a structured format (e.g. node features plus an adjacency matrix). Examples also need to be batched with the block-diagonal adjacency-matrix trick, but I might be able to handle that myself in my custom Net subclass.
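For reference, that batching trick usually merges all graphs in a batch into one large disconnected graph whose adjacency matrix is block-diagonal, plus a node-to-graph index used for pooling. Below is a minimal NumPy sketch that does not depend on Tianshou. `batch_graphs` and `sum_readout` are hypothetical helper names, not library functions:

```python
import numpy as np


def batch_graphs(adjs, feats):
    """Merge graphs into one disconnected graph with a block-diagonal adjacency.

    adjs:  list of (n_i, n_i) adjacency matrices
    feats: list of (n_i, F) node-feature matrices
    Returns the (N, N) adjacency, the (N, F) stacked features and a length-N
    vector giving the source graph of every node, where N = sum(n_i).
    """
    sizes = [a.shape[0] for a in adjs]
    total = sum(sizes)
    big_adj = np.zeros((total, total), dtype=adjs[0].dtype)
    offset = 0
    for a, n in zip(adjs, sizes):
        # Each graph gets its own diagonal block, so no edges cross graphs.
        big_adj[offset:offset + n, offset:offset + n] = a
        offset += n
    big_feats = np.concatenate(feats, axis=0)
    graph_index = np.repeat(np.arange(len(adjs)), sizes)
    return big_adj, big_feats, graph_index


def sum_readout(node_h, graph_index, num_graphs):
    """Sum node embeddings per graph to get one vector per original example."""
    out = np.zeros((num_graphs, node_h.shape[1]), dtype=node_h.dtype)
    np.add.at(out, graph_index, node_h)  # unbuffered scatter-add by graph id
    return out


a1 = np.array([[0, 1], [1, 0]])                    # 2-node graph
a2 = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])   # 3-node path
adj, x, idx = batch_graphs([a1, a2], [np.ones((2, 4)), np.zeros((3, 4))])
pooled = sum_readout(x, idx, num_graphs=2)
```

In PyTorch, `torch.block_diag` builds the same block structure directly from a list of tensors.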

I can't find a way for an environment to pass structured data either as observations or in the "info" dictionary. Models don't receive the full observation dict (only an extracted array), and the info dict isn't always available.
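Because the model only receives a plain array, one possible workaround is to pack each graph into a fixed-length, zero-padded vector. That vector fits an ordinary `gym.spaces.Box` observation and can be unpacked again inside the model. Below is a minimal NumPy sketch. `MAX_NODES`, `FEAT_DIM`, `pack_graph` and `unpack_graph` are hypothetical names. The layout `[num_nodes | node features | adjacency]` is an assumption and is not something Tianshou defines:

```python
import numpy as np

MAX_NODES = 4  # hypothetical upper bound on nodes per graph
FEAT_DIM = 2   # hypothetical node-feature size
# Layout: [num_nodes, node features (MAX_NODES*FEAT_DIM), adjacency (MAX_NODES**2)]
OBS_DIM = 1 + MAX_NODES * FEAT_DIM + MAX_NODES * MAX_NODES
# e.g. in the env:
# observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(OBS_DIM,), dtype=np.float32)


def pack_graph(feats, adj):
    """Flatten an (n, FEAT_DIM) feature matrix and (n, n) adjacency into OBS_DIM floats."""
    n = feats.shape[0]
    if n > MAX_NODES:
        raise ValueError(f"graph has {n} nodes, MAX_NODES is {MAX_NODES}")
    padded_feats = np.zeros((MAX_NODES, FEAT_DIM), dtype=np.float32)
    padded_feats[:n] = feats
    padded_adj = np.zeros((MAX_NODES, MAX_NODES), dtype=np.float32)
    padded_adj[:n, :n] = adj
    packed = np.concatenate([[n], padded_feats.ravel(), padded_adj.ravel()])
    return packed.astype(np.float32)


def unpack_graph(vec):
    """Inverse of pack_graph: recover the unpadded features and adjacency."""
    n = int(vec[0])
    start = 1
    feats = vec[start:start + MAX_NODES * FEAT_DIM].reshape(MAX_NODES, FEAT_DIM)
    start += MAX_NODES * FEAT_DIM
    adj = vec[start:start + MAX_NODES * MAX_NODES].reshape(MAX_NODES, MAX_NODES)
    return feats[:n], adj[:n, :n]


# Round trip for a 3-node path graph.
feats = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=np.float32)
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=np.float32)
vec = pack_graph(feats, adj)  # shape (25,)
f2, a2 = unpack_graph(vec)
```

The trade-off is that padding wastes space for small graphs and caps the graph size at `MAX_NODES`.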

Here's a minimal example demonstrating my issue:

```python
import logging

import gymnasium as gym
import tianshou as ts
import torch
from tianshou.utils.net.common import Net

logger = logging.getLogger(__name__)


class TestNet(Net):
  ''' Prints args provided to the forward method. '''

  def __init__(self, s_shape, a_shape, h_sizes):
    super().__init__(state_shape=s_shape, action_shape=a_shape, hidden_sizes=h_sizes)

  def forward(self, obs, state=None, info=None):
    logger.warning(f'Model received: obs={obs} state={state} info={info}')
    return super().forward(obs, state, info)


class TestEnv(gym.Env):
  ''' Uses a dict state space. '''

  def __init__(self):
    super().__init__()
    self.action_space = gym.spaces.Discrete(2)
    self.observation_space = gym.spaces.Dict({
        # It appears the value of the 'obs' key is what is passed to the model.
        'obs': gym.spaces.Box(low=0, high=1, shape=(2,)),
        'extra_data': gym.spaces.Box(low=0, high=1, shape=(2,)),
    })

  def reset(self, seed=None, options=None):
    super().reset(seed=seed)
    return self.observation_space.sample(), {}

  def step(self, action):
    return self.observation_space.sample(), 0, True, False, {}


def test():
  env = TestEnv()

  net = TestNet(s_shape=(2, ), a_shape=2, h_sizes=[16, 16])
  optim = torch.optim.Adam(net.parameters(), lr=1e-2)
  policy = ts.policy.DQNPolicy(model=net, optim=optim, action_space=env.action_space)

  collector = ts.data.Collector(policy, env)
  collector.reset()
  collector.collect(n_step=1)
  # Prints:
  #   Model received: obs=[[0.15468411 0.8790275 ]] state=None info=Batch()

  test_collector = ts.data.Collector(policy, env)
  test_collector.reset()

  ts.trainer.OffpolicyTrainer(
      policy=policy,
      train_collector=collector,
      test_collector=test_collector,
      max_epoch=1,
      step_per_epoch=1,
      step_per_collect=1,
      batch_size=1,
      episode_per_test=1,
  ).run()
  # Prints:
  #   Model received: obs=[[0.5437955  0.99715394]] state=None info=[None]

test()
```

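The first print shows that my model only receives the array stored under the 'obs' key, not the full structured observation. The second shows that during training the model receives info=[None], so I can't route my structured data through the info dict either.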
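The first print also shows that `forward` receives a 2-D array with the batch dimension first (`obs=[[...]]`). If each row holds a packed graph, the whole batch can therefore be decoded at once. Below is a minimal vectorised NumPy sketch. It assumes a hypothetical per-row layout `[num_nodes | MAX_NODES*FEAT_DIM node features | MAX_NODES**2 adjacency]`, zero-padded. It produces padded dense tensors plus a node mask, which is an alternative to block-diagonal batching. `decode_batch` and `mean_aggregate` are hypothetical names. In a real `Net.forward`, `obs` may arrive as a NumPy array or a tensor and would need converting with `torch.as_tensor`:

```python
import numpy as np

MAX_NODES = 3  # hypothetical upper bound on nodes per graph
FEAT_DIM = 2   # hypothetical node-feature size


def decode_batch(obs):
    """Split a (B, 1 + MAX_NODES*FEAT_DIM + MAX_NODES**2) array into dense tensors.

    Returns features (B, MAX_NODES, FEAT_DIM), adjacency (B, MAX_NODES, MAX_NODES)
    and a boolean node mask (B, MAX_NODES) that is False for padding nodes.
    """
    obs = np.asarray(obs, dtype=np.float32)
    batch = obs.shape[0]
    num_nodes = obs[:, 0].astype(np.int64)
    f_end = 1 + MAX_NODES * FEAT_DIM
    feats = obs[:, 1:f_end].reshape(batch, MAX_NODES, FEAT_DIM)
    adj = obs[:, f_end:].reshape(batch, MAX_NODES, MAX_NODES)
    mask = np.arange(MAX_NODES)[None, :] < num_nodes[:, None]
    return feats, adj, mask


def mean_aggregate(feats, adj, mask):
    """One round of mean neighbour aggregation that ignores padding nodes."""
    adj = adj * mask[:, None, :]              # drop edges into padding nodes
    deg = adj.sum(axis=2, keepdims=True)      # (B, MAX_NODES, 1)
    agg = adj @ feats / np.maximum(deg, 1.0)  # avoid dividing by zero degree
    return agg * mask[:, :, None]             # zero out padding rows


# Two packed graphs: a 2-node graph (third slot is padding) and a 3-node path.
feats_a = np.array([[1, 1], [3, 3], [0, 0]])
adj_a = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]])
feats_b = np.array([[1, 0], [0, 1], [2, 2]])
adj_b = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
obs = np.stack([
    np.concatenate([[2], feats_a.ravel(), adj_a.ravel()]),
    np.concatenate([[3], feats_b.ravel(), adj_b.ravel()]),
])
feats, adj, mask = decode_batch(obs)
out = mean_aggregate(feats, adj, mask)
```

Padded dense batching keeps every tensor a fixed shape, which suits plain batched matrix multiplies. The cost is wasted compute on padding slots.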
