Improvements pertaining to the handling of multi-experiment creation #1131

Merged: 6 commits, May 5, 2024
53 changes: 34 additions & 19 deletions CHANGELOG.md
@@ -3,27 +3,42 @@
## Release 1.1.0

### Api Extensions
- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
- `SamplingConfig` supports `batch_size=None`. #1077
- Batch received new methods: `to_numpy_` and `to_torch_`. #1098, #1117
- `to_dict` in Batch supports also non-recursive conversion. #1098
- Batch `__eq__` implemented, semantic equality check of batches is now possible. #1098
- `data`:
- `Batch`:
- Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098
- Add methods `to_numpy_` and `to_torch_`. #1098, #1117
- Add `__eq__` (semantic equality check). #1098
- `data.collector`:
- `Collector`:
- Add method `close` #1063
- Method `reset` is now more granular (new flags controlling behavior). #1063
- `CollectStats`: Add convenience constructor `with_autogenerated_stats`. #1063
- `trainer`:
- Trainers can now control whether collectors should be reset prior to training. #1063
- `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105.
- `Experiment` and `ExperimentConfig` now have a `name`, that can however be overridden when `Experiment.run()` is called. #1074
- When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. #1074
- New method in `ExperimentConfig` called `build_default_seeded_experiments`. #1074
- `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074
- New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!).
Launchers for parallelization currently in alpha state. #1074
- `highlevel`:
- `SamplingConfig`:
- Add support for `batch_size=None`. #1077
- Add `training_seed` for explicit seeding of training and test environments; the `test_seed` is inferred from `training_seed`. #1074
- `highlevel.experiment`:
- `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and
which determines the default run name and therefore the persistence subdirectory.
It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than
`experiment_name` (although the latter will still be interpreted correctly). #1074 #1131
- Add class `ExperimentCollection` for the convenient execution of multiple experiment runs. #1131
- `ExperimentBuilder`:
- Add method `build_seeded_collection` for the sound creation of multiple
experiments with varying random seeds. #1131
- Add method `copy` to facilitate the creation of multiple experiments from a single builder. #1131
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074
- The module `evaluation.launchers` for parallelization is currently in alpha state.
- Loggers can now restore the logged data into Python by using the new `restore_logged_data` method. #1074
- `continuous.Critic`:
- Add flag `apply_preprocess_net_to_obs_only` to allow the
preprocessing network to be applied to the observations only (without
the actions concatenated), which is essential for the case where we want
to reuse the actor's preprocessing network #1128
- `utils.net`:
- `continuous.Critic`:
- Add flag `apply_preprocess_net_to_obs_only` to allow the
preprocessing network to be applied to the observations only (without
the actions concatenated), which is essential for the case where we want
to reuse the actor's preprocessing network #1128

### Fixes
- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
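The changelog entries above describe the new multi-experiment API at a high level; the sketch below shows how the pieces fit together. It is illustrative only: the import paths and the `env_factory` placeholder are assumptions inferred from the package names in the changelog, while the builder, collection, and launcher calls mirror the rewritten `examples/mujoco/mujoco_ppo_hl_multi.py` further down.

```python
# Illustrative sketch of the workflow introduced by this PR; not part of the diff.
# ASSUMPTIONS: the import paths below and the env_factory placeholder are inferred,
# not taken from this PR. See examples/mujoco/mujoco_ppo_hl_multi.py for the real setup.
from tianshou.evaluation.launcher import RegisteredExpLauncher  # assumed module path
from tianshou.highlevel.config import SamplingConfig  # assumed module path
from tianshou.highlevel.experiment import ExperimentConfig, PPOExperimentBuilder  # assumed module path

experiment_config = ExperimentConfig(persistence_base_dir="log/demo", watch=False)
sampling_config = SamplingConfig(num_epochs=1, step_per_epoch=5000)
env_factory = ...  # placeholder: e.g. a MujocoEnvFactory, as in the example below

builder = PPOExperimentBuilder(env_factory, experiment_config, sampling_config)

# Single run: Experiment.run() now takes `run_name` (previously `override_experiment_name`).
experiment = builder.build()
experiment.run(run_name="ppo-demo")

# Multi-seed run: build_seeded_collection() creates one experiment per seed and
# returns an ExperimentCollection, which is executed through a launcher.
collection = builder.build_seeded_collection(5)
collection.run(RegisteredExpLauncher.sequential.create_launcher())

# copy() returns an independent builder, so further experiments can be derived
# without mutating the configuration used above.
variant_builder = builder.copy()
```

For parallel execution, the joblib-based launcher (`RegisteredExpLauncher.joblib`) can be substituted for the sequential one, subject to the Linux caveat noted in the example's docstring.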
2 changes: 1 addition & 1 deletion examples/atari/atari_dqn_hl.py
@@ -104,7 +104,7 @@ def main(
)

experiment = builder.build()
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/atari/atari_iqn_hl.py
@@ -96,7 +96,7 @@ def main(
.with_epoch_stop_callback(AtariEpochStopCallback(task))
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/atari/atari_ppo_hl.py
@@ -115,7 +115,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/atari/atari_sac_hl.py
@@ -103,7 +103,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_a2c_hl.py
@@ -83,7 +83,7 @@ def main(
.with_critic_factory_default(hidden_sizes, nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_ddpg_hl.py
@@ -74,7 +74,7 @@ def main(
.with_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_npg_hl.py
@@ -85,7 +85,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_ppo_hl.py
@@ -95,7 +95,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
135 changes: 35 additions & 100 deletions examples/mujoco/mujoco_ppo_hl_multi.py
@@ -14,8 +14,6 @@

import os
import sys
from collections.abc import Sequence
from typing import Literal

import torch

@@ -41,86 +39,30 @@


def main(
experiment_config: ExperimentConfig,
task: str = "Ant-v4",
num_experiments: int = 5,
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 3e-4,
gamma: float = 0.99,
epoch: int = 3,
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 10,
batch_size: int = 64,
training_num: int = 10,
test_num: int = 10,
rew_norm: bool = True,
vf_coef: float = 0.25,
ent_coef: float = 0.0,
gae_lambda: float = 0.95,
bound_action_method: Literal["clip", "tanh"] | None = "clip",
lr_decay: bool = True,
max_grad_norm: float = 0.5,
eps_clip: float = 0.2,
dual_clip: float | None = None,
value_clip: bool = False,
norm_adv: bool = False,
recompute_adv: bool = True,
num_experiments: int = 2,
run_experiments_sequentially: bool = True,
) -> str:
"""Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for
a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained,
the results are evaluated using the rliable API.

:param experiment_config:
:param task: a mujoco task name
:param num_experiments: how many experiments to run with different seeds
:param buffer_size:
:param hidden_sizes:
:param lr:
:param gamma:
:param epoch:
:param step_per_epoch:
:param step_per_collect:
:param repeat_per_collect:
:param batch_size:
:param training_num:
:param test_num:
:param rew_norm:
:param vf_coef:
:param ent_coef:
:param gae_lambda:
:param bound_action_method:
:param lr_decay:
:param max_grad_norm:
:param eps_clip:
:param dual_clip:
:param value_clip:
:param norm_adv:
:param recompute_adv:
:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
) -> RLiableExperimentResult:
""":param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
LIMITATIONS: currently, the parallel execution does not seem to work properly on Linux.
It might generally be undesired to run multiple experiments in parallel on the same machine,
as a single experiment already uses all available CPU cores by default.
:return: the directory where the results are stored
"""
task = "Ant-v4"
persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))

experiment_config.persistence_base_dir = persistence_dir
log.info(f"Will save all experiment results to {persistence_dir}.")
experiment_config.watch = False
experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False)

sampling_config = SamplingConfig(
num_epochs=epoch,
step_per_epoch=step_per_epoch,
batch_size=batch_size,
num_train_envs=training_num,
num_test_envs=test_num,
num_test_episodes=test_num,
buffer_size=buffer_size,
step_per_collect=step_per_collect,
repeat_per_collect=repeat_per_collect,
num_epochs=1,
step_per_epoch=5000,
batch_size=64,
num_train_envs=10,
num_test_envs=10,
num_test_episodes=10,
buffer_size=4096,
step_per_collect=2048,
repeat_per_collect=10,
)

env_factory = MujocoEnvFactory(
@@ -133,52 +75,45 @@ def main(
else VectorEnvType.SUBPROC_SHARED_MEM,
)

experiments = (
hidden_sizes = (64, 64)

experiment_collection = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_ppo_params(
PPOParams(
discount_factor=gamma,
gae_lambda=gae_lambda,
action_bound_method=bound_action_method,
reward_normalization=rew_norm,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
value_clip=value_clip,
advantage_normalization=norm_adv,
eps_clip=eps_clip,
dual_clip=dual_clip,
recompute_advantage=recompute_adv,
lr=lr,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
discount_factor=0.99,
gae_lambda=0.95,
action_bound_method="clip",
reward_normalization=True,
ent_coef=0.0,
vf_coef=0.25,
max_grad_norm=0.5,
value_clip=False,
advantage_normalization=False,
eps_clip=0.2,
dual_clip=None,
recompute_advantage=True,
lr=3e-4,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config),
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.with_logger_factory(LoggerFactoryDefault("tensorboard"))
.build_default_seeded_experiments(num_experiments)
.build_seeded_collection(num_experiments)
)

if run_experiments_sequentially:
launcher = RegisteredExpLauncher.sequential.create_launcher()
else:
launcher = RegisteredExpLauncher.joblib.create_launcher()
launcher.launch(experiments)

return persistence_dir

experiment_collection.run(launcher)

def eval_experiments(log_dir: str) -> RLiableExperimentResult:
"""Evaluate the experiments in the given log directory using the rliable API."""
rliable_result = RLiableExperimentResult.load_from_disk(log_dir)
rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir)
rliable_result.eval_results(show_plots=True, save_plots=True)
return rliable_result


if __name__ == "__main__":
log_dir = logging.run_cli(main, level=logging.INFO)
assert isinstance(log_dir, str) # for mypy
evaluation_result = eval_experiments(log_dir)
result = logging.run_cli(main, level=logging.INFO)
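A usage note on the rewritten example: `main` now returns the aggregated rliable result rather than a log directory, so the script can also be driven programmatically instead of through `logging.run_cli`. A minimal sketch, assuming the repository root is on the Python path (the import path is an assumption):

```python
# Sketch: invoking the rewritten multi-seed example programmatically.
# ASSUMPTION: the import path presumes execution from the repository root.
from examples.mujoco.mujoco_ppo_hl_multi import main

# Runs two seeded PPO experiments sequentially; eval_results() inside main
# shows and saves the rliable plots, and the result object is returned.
result = main(num_experiments=2, run_experiments_sequentially=True)
```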
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_redq_hl.py
@@ -83,7 +83,7 @@ def main(
.with_critic_ensemble_factory_default(hidden_sizes)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_reinforce_hl.py
@@ -72,7 +72,7 @@ def main(
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_sac_hl.py
@@ -80,7 +80,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_td3_hl.py
@@ -85,7 +85,7 @@ def main(
.with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_trpo_hl.py
@@ -89,7 +89,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(override_experiment_name=log_name)
experiment.run(run_name=log_name)


if __name__ == "__main__":