
Draft: for explicit seed mechanism of train seed, specific test seeds #1031


Closed · wants to merge 4 commits
35 changes: 30 additions & 5 deletions examples/mujoco/mujoco_env.py
@@ -6,6 +6,7 @@

from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.highlevel.experiment import TrainSeedMechanism
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World

@@ -17,7 +18,11 @@
log = logging.getLogger(__name__)


def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
def make_mujoco_env(
    task: str,
    seed: int,
    num_train_envs: int,
    num_test_envs: int,
    obs_norm: bool,
    train_seed_mechanism: TrainSeedMechanism = TrainSeedMechanism.NONE,
    test_seeds: tuple[int, ...] | None = None,
):  # NOTE: also builds and returns the base env, so the name undersells what it does
"""Wrapper function for Mujoco env.

If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.
@@ -27,6 +32,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
if envpool is not None:
train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
# TODO: verify how envpool applies the seed here (per-env offsets vs. a shared seed)
else:
warnings.warn(
"Recommend using envpool (pip install envpool) "
@@ -35,8 +41,21 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
env = gym.make(task)
train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
train_envs.seed(seed)
test_envs.seed(seed)
if train_seed_mechanism.is_consecutive():
    train_envs.seed([seed + i for i in range(num_train_envs)])
elif train_seed_mechanism.is_repeat():
    train_envs.seed([seed for _ in range(num_train_envs)])
elif train_seed_mechanism.is_none():
    train_envs.seed(seed)
else:
    raise NotImplementedError(f"train_seed_mechanism {train_seed_mechanism} not implemented")

if test_seeds is None:
    test_envs.seed(seed)
else:
    assert len(test_seeds) == num_test_envs
    test_envs.seed(test_seeds)

if obs_norm:
# obs norm wrapper
train_envs = VectorEnvNormObs(train_envs)
@@ -69,10 +88,14 @@ def restore(self, event: RestoreEvent, world: World):


class MujocoEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, obs_norm=True):
def __init__(self, task: str, seed: int, obs_norm: bool = True,
             train_seed_mechanism: TrainSeedMechanism = TrainSeedMechanism.NONE,
             test_seeds: tuple[int, ...] | None = None):
self.task = task
self.seed = seed
self.obs_norm = obs_norm
self.train_seed_mechanism = train_seed_mechanism
self.test_seeds = test_seeds

def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env(
@@ -81,8 +104,10 @@ def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousE
num_train_envs=num_training_envs,
num_test_envs=num_test_envs,
obs_norm=self.obs_norm,
train_seed_mechanism=self.train_seed_mechanism,
test_seeds=self.test_seeds,
)
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs, test_seeds=self.test_seeds)
if self.obs_norm:
envs.set_persistence(MujocoEnvObsRmsPersistence())
return envs
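
For reference, a minimal sketch of the seed lists each mechanism produces (assuming TrainSeedMechanism is an enum exposing is_consecutive/is_repeat/is_none, as the diff suggests; values are illustrative):

# Sketch, not part of the PR: seed=10, num_train_envs=4.
# NONE        -> train_envs.seed(10); the vector env expands an int to [10, 11, 12, 13]
# CONSECUTIVE -> explicit [10, 11, 12, 13] (same result as NONE, spelled out)
# REPEAT      -> [10, 10, 10, 10], i.e. every train env replays the same seed
seed, num_train_envs = 10, 4
consecutive = [seed + i for i in range(num_train_envs)]
repeat = [seed for _ in range(num_train_envs)]
assert consecutive == [10, 11, 12, 13] and repeat == [10, 10, 10, 10]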
103 changes: 103 additions & 0 deletions examples/mujoco/mujoco_ppo_hl_experiment.py
@@ -0,0 +1,103 @@
#!/usr/bin/env python3

import os
from collections.abc import Sequence
from typing import Literal

import torch

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    EvaluationProtocolExperimentConfig,
    PPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
experiment_config: EvaluationProtocolExperimentConfig,
task: str = "Ant-v4",
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 3e-4,
gamma: float = 0.99,
epoch: int = 100,
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 10,
batch_size: int = 64,
training_num: int = 3,  # reduced from 64 for fast iteration
test_num: int = 2,  # reduced from 10 for fast iteration
rew_norm: bool = True,
vf_coef: float = 0.25,
ent_coef: float = 0.0,
gae_lambda: float = 0.95,
bound_action_method: Literal["clip", "tanh"] | None = "clip",
lr_decay: bool = True,
max_grad_norm: float = 0.5,
eps_clip: float = 0.2,
dual_clip: float | None = None,
value_clip: bool = False,
norm_adv: bool = False,
recompute_adv: bool = True,
):
log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag())

sampling_config = SamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
)

env_factory = MujocoEnvFactory(
    task,
    experiment_config.seed,
    obs_norm=True,
    train_seed_mechanism=experiment_config.train_seed_mechanism,
    test_seeds=experiment_config.test_seeds,
)

experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_ppo_params(
PPOParams(
discount_factor=gamma,
gae_lambda=gae_lambda,
action_bound_method=bound_action_method,
reward_normalization=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
advantage_normalization=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_advantage=recompute_adv,
lr=lr,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(log_name, experiment_config.record_seed_of_transition_to_buffer_test)


if __name__ == "__main__":
logging.run_cli(main)
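
For illustration, the script can also be driven without the CLI wrapper. A sketch, assuming the keyword names on EvaluationProtocolExperimentConfig match those used elsewhere in this diff:

# Hypothetical direct invocation of main() with an explicit config.
from tianshou.highlevel.experiment import (
    EvaluationProtocolExperimentConfig,
    TrainSeedMechanism,
)

config = EvaluationProtocolExperimentConfig(
    seed=42,
    train_seed_mechanism=TrainSeedMechanism.CONSECUTIVE,
    test_seeds=(2, 3),
)
main(config, task="Ant-v4", epoch=1)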
71 changes: 71 additions & 0 deletions test/highlevel/test_eval_procedure.py
@@ -0,0 +1,71 @@

import pytest

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    EvaluationProtocolExperimentConfig,
    ExperimentConfig,
    PPOExperimentBuilder,
    TrainSeedMechanism,
)
from examples.mujoco.mujoco_env import MujocoEnvFactory


def test_standard_experiment_config():
experiment_config = ExperimentConfig()
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
env_factory = MujocoEnvFactory(task="Ant-v4",
seed=experiment_config.seed,
obs_norm=True,)

ppo = PPOExperimentBuilder(
experiment_config=experiment_config,
env_factory=env_factory,
sampling_config=sampling_config,
)
experiment = ppo.build()
experiment.run("test")
print(experiment)


@pytest.mark.parametrize(
    "experiment_config",
    [
        EvaluationProtocolExperimentConfig(
            persistence_enabled=False,
            train_seed_mechanism=TrainSeedMechanism.CONSECUTIVE,
            test_seeds=(2, 3),
        ),
        EvaluationProtocolExperimentConfig(
            persistence_enabled=False,
            train_seed_mechanism=TrainSeedMechanism.REPEAT,
            test_seeds=(2, 3),
        ),
        EvaluationProtocolExperimentConfig(
            persistence_enabled=False,
            train_seed_mechanism=TrainSeedMechanism.NONE,
            test_seeds=(2, 3),
        ),
    ],
)
def test_experiment_builder_continuous_default_params(experiment_config):
sampling_config = SamplingConfig(
num_epochs=1,
step_per_epoch=100,
num_train_envs=2,
num_test_envs=2,
)
env_factory = MujocoEnvFactory(task="Ant-v4",
seed=experiment_config.seed,
obs_norm=True,
train_seed_mechanism=experiment_config.train_seed_mechanism,
test_seeds=experiment_config.test_seeds)

ppo = PPOExperimentBuilder(
experiment_config=experiment_config,
env_factory=env_factory,
sampling_config=sampling_config,
)
experiment = ppo.build()
experiment.run("test",
record_seed_of_transition_to_buffer_test=experiment_config.record_seed_of_transition_to_buffer_test)
print(experiment)
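
A narrower unit test could pin down just the widened seed() signature. A sketch, assuming DummyVectorEnv inherits the seed() semantics shown in the venvs.py hunk below:

import gymnasium as gym

from tianshou.env import DummyVectorEnv


def test_seed_accepts_tuple():
    # The PR widens seed() to accept a tuple; expect one seed list per env.
    venv = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
    seed_lists = venv.seed((2, 3))
    assert len(seed_lists) == 2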
8 changes: 7 additions & 1 deletion tianshou/data/collector.py
@@ -94,6 +94,8 @@ def __init__(
buffer: ReplayBuffer | None = None,
preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
exploration_noise: bool = False,
record_seed_of_transition_to_buffer_test: bool = False,
test_seeds: tuple[int, ...] | None = None,
) -> None:
super().__init__()
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
@@ -111,6 +113,8 @@
self.data: RolloutBatchProtocol
# avoid creating attribute outside __init__
self.reset(False)
self.record_seed_of_transition_to_buffer_test = record_seed_of_transition_to_buffer_test
self.test_seeds = test_seeds

def _assign_buffer(self, buffer: ReplayBuffer | None) -> None:
"""Check if the buffer matches the constraint."""
@@ -317,7 +321,9 @@ def collect(
ready_env_ids,
)
done = np.logical_or(terminated, truncated)

if self.record_seed_of_transition_to_buffer_test:
for env_id, active_id in enumerate(ready_env_ids):
info[env_id]["seed"] = self.test_seeds[active_id]
self.data.update(
obs_next=obs_next,
rew=rew,
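
The tagging logic above can be mirrored in isolation to see what it writes into info; a sketch with hypothetical values:

# Sketch: with test_seeds=(2, 3) and both test envs ready, each transition's
# info dict is tagged with the seed of the global env it came from.
test_seeds = (2, 3)
ready_env_ids = [0, 1]
info = [{}, {}]
for env_id, active_id in enumerate(ready_env_ids):
    info[env_id]["seed"] = test_seeds[active_id]
assert [d["seed"] for d in info] == [2, 3]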
6 changes: 3 additions & 3 deletions tianshou/env/venvs.py
@@ -315,19 +315,19 @@ def step(
np.stack(info_list),
)

def seed(self, seed: int | list[int] | None = None) -> list[list[int] | None]:
def seed(self, seed: int | list[int] | tuple[int, ...] | None = None) -> list[list[int] | None]:
"""Set the seed for all environments.

Accept ``None``, an int (which will extend ``i`` to
``[i, i + 1, i + 2, ...]``), or a list/tuple of ints.

:return: The list of seeds used in this env's random number generators.
The first value in the list should be the "main" seed, or the value
which a reproducer pass to "seed".
which a reproducer passes to "seed".
"""
self._assert_is_not_closed()
seed_list: list[None] | list[int]
if seed is None:
if seed is None:  # TODO: confirm whether seed=None can actually occur; it currently errors when set via config

[Collaborator comment] Forbid seed=None?

seed_list = [seed] * self.env_num
elif isinstance(seed, int):
seed_list = [seed + i for i in range(self.env_num)]
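
One way to act on the review suggestion (a sketch, not part of the diff) is to fail fast instead of propagating None:

if seed is None:
    # Reviewer-suggested alternative: reject None outright rather than
    # seeding every worker with None.
    raise ValueError("seed=None is not supported; pass an int or a sequence of ints")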
7 changes: 6 additions & 1 deletion tianshou/highlevel/agent.py
@@ -93,6 +93,8 @@ def create_train_test_collector(
self,
policy: BasePolicy,
envs: Environments,
record_seed_of_transition_to_buffer_test: bool = False,
test_seeds: tuple[int, ...] | None = None,
) -> tuple[Collector, Collector]:
buffer_size = self.sampling_config.buffer_size
train_envs = envs.train_envs
@@ -113,7 +115,10 @@
ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next,
)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
test_collector = Collector(
    policy,
    envs.test_envs,
    record_seed_of_transition_to_buffer_test=record_seed_of_transition_to_buffer_test,
    test_seeds=test_seeds,
)
self.test_seeds = test_seeds
if self.sampling_config.start_timesteps > 0:
log.info(
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
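
For orientation, a hypothetical call site threading the new arguments through (agent_factory, policy, and envs are assumed to exist; only the keyword names come from the diff):

train_collector, test_collector = agent_factory.create_train_test_collector(
    policy,
    envs,
    record_seed_of_transition_to_buffer_test=True,
    test_seeds=envs.test_seeds,
)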
4 changes: 3 additions & 1 deletion tianshou/highlevel/env.py
@@ -147,9 +147,11 @@ def get_type(self) -> EnvType:
class ContinuousEnvironments(Environments):
"""Represents (vectorized) continuous environments."""

def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv,
             test_seeds: tuple[int, ...] | None = None):
super().__init__(env, train_envs, test_envs)
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
self.test_seeds = test_seeds

@staticmethod
def from_factory(