diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md
index 0f1e7f9a2..95ae47e05 100644
--- a/examples/mujoco/README.md
+++ b/examples/mujoco/README.md
@@ -12,9 +12,10 @@ For each supported algorithm and supported mujoco environments, we provide:
 
 Supported algorithms are listed below:
 
-- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
-- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
-- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
+- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
+- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
+- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
+- [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9)
 
 ## Offpolicy algorithms
 
@@ -109,10 +110,44 @@ By comparison to both classic literature and open source implementations (e.g., 
 
 ## Onpolicy Algorithms
 
-TBD
+### REINFORCE
+
+| Environment | Tianshou (10M steps) |
+| :--------------------: | :-----------------: |
+| Ant | **1108.1±323.1** |
+| HalfCheetah | **1138.8±104.7** |
+| Hopper | **416.0±104.7** |
+| Walker2d | **440.9±148.2** |
+| Swimmer | **35.6±2.6** |
+| Humanoid | **464.3±58.4** |
+| Reacher | **-5.5±0.2** |
+| InvertedPendulum | **1000.0±0.0** |
+| InvertedDoublePendulum | **7726.2±1287.3** |
+
+| Environment | Tianshou (3M steps) | [SpinningUp (VPG PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html)[[10]](#footnote10) |
+| :--------------------: | :--------------------------: | :------------------------: |
+| Ant | **474.9±133.5** | ~5 |
+| HalfCheetah | **884.0±41.0** | ~600 |
+| Hopper | 395.8±64.5\* | **~800** |
+| Walker2d | 412.0±52.4 | **~460** |
+| Swimmer | 35.3±1.4 | **~51** |
+| Humanoid | **438.2±47.8** | N |
+| Reacher | **-10.5±0.7** | N |
+| InvertedPendulum | **999.2±2.4** | N |
+| InvertedDoublePendulum | **1059.7±307.7** | N |
+\* details[[5]](#footnote5)[[6]](#footnote6)
+#### Hints for REINFORCE
+0. Following [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990), we downscale the last layer of the policy network by a factor of 0.01 after orthogonal initialization.
+1. We use the "tanh" function to squash sampled actions from (-inf, inf) to (-1, 1), rather than the more common clipping method (as in Stable-Baselines3). Our full-scale ablation studies show that tanh squashing performs slightly better than clipping overall, and much better than no action bounding at all. However, "clip" is still a very good method given its simplicity.
+2. We use global observation normalization and global reward-to-go (value) normalization by default. Both are crucial to the good performance of REINFORCE. Since we subtract the mean during reward-to-go normalization, the global mean of the reward-to-go acts as a naive "baseline".
+3. Since we do not have a value estimator, we bootstrap steps truncated by the time limit or by unfinished collection with the global reward-to-go mean, whereas most other implementations use 0. We expect this to help because the mean is likely a better estimate than 0 (no ablation study has been done).
+4. We have done a full-scale ablation study on the learning rate and the lr decay strategy. We experimented with lr of 3e-4, 5e-4 and 1e-3, each with 2 options: no lr decay or linear decay to 0. Experiments show that with a 3e-4 learning rate the agent learns slowly and easily gets stuck in local optima in certain environments such as InvertedDoublePendulum, Ant and HalfCheetah, and that a 1e-3 lr helps a lot. However, after about 5M steps of training with lr 1e-3, agents in certain environments such as InvertedPendulum become unstable. Our conclusion is that you should start with a large learning rate and decay it linearly, but if you use a small initial learning rate or train for only a limited number of timesteps, DO NOT decay it.
+5. We didn't tune the `step-per-collect` and `training-num` options. The default values were fine-tuned with PPO, so we assume they are also good for REINFORCE. You can play with them if you want, but remember that `buffer-size` should always be larger than `step-per-collect`, and that if `step-per-collect` is too small and `training-num` too large, episodes will be truncated and bootstrapped very often, which hurts performance. If `training-num` is too small (e.g., less than 8), training will be slower.
+6. The action sigma is neither fixed (as is common in other implementations) nor conditioned on the observation; instead it is an independent parameter updated by gradient descent. We chose this setting because it works well in PPO and is recommended by [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990) (see Fig. 23).
 
 ## Note
 
@@ -126,10 +161,12 @@ TBD
 
 [5] ~ means the number is approximated from the graph because accurate numbers is not provided in the paper. N means graphs not provided.
 
-[6] Reward metric: The meaning of the table value is the max average return over 10 trails (different seeds) ± a single standard deviation over trails. Each trial is averaged on another 10 test seeds. Only the first 1M steps data will be considered. The shaded region on the graph also represents a single standard deviation. It is the same as [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34).
+[6] Reward metric: The table value is the max average return over 10 trials (different seeds) ± a single standard deviation over trials. Each trial is averaged over another 10 test seeds. Only data from the first 1M steps is considered, unless otherwise stated. The shaded region on the graph also represents a single standard deviation. This is the same as the [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34).
 
 [7] In TD3 paper, shaded region represents only half of standard deviation.
 
 [8] SAC's start-timesteps is set to 10000 by default while it is 25000 is DDPG/TD3. TD3's learning rate is set to 3e-4 while it is 1e-3 for DDPG/SAC. However, there is NO enough evidence to support our choice of such hyperparameters (we simply choose them because of SpinningUp) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can!
 
 [9] We use batchsize of 256 in DDPG/TD3/SAC while SpinningUp use 100. Minor difference also lies with `start-timesteps`, data loop method `step_per_collect`, method to deal with/bootstrap truncated steps because of timelimit and unfinished/collecting episodes (contribute to performance improvement), etc.
+
+[10] Comparing Tianshou's REINFORCE with SpinningUp's VPG is quite unfair, because SpinningUp's VPG uses generalized advantage estimation (GAE), which requires a DNN value predictor (critic network); this makes the so-called "VPG" closer to an A2C (advantage actor-critic) algorithm. Even so, the two are roughly on par, although Tianshou's REINFORCE uses neither a critic nor GAE.
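Hints 2 and 3 describe how REINFORCE builds its training targets: a discounted reward-to-go per step, a truncated episode bootstrapped with the global reward-to-go mean instead of 0, and normalization with global statistics whose mean subtraction acts as a naive baseline. Below is a minimal pure-Python sketch of those three steps, assuming one episode is given as a plain list of rewards. The helper names (`reward_to_go`, `RunningMeanStd`, `normalize`) are hypothetical; this is not Tianshou's `PGPolicy` code, which works on batched buffers.

```python
import math
from typing import List, Sequence


def reward_to_go(rewards: Sequence[float], gamma: float,
                 bootstrap: float = 0.0) -> List[float]:
    """Discounted reward-to-go: G_t = r_t + gamma * G_{t+1}.

    `bootstrap` stands in for the return after the last collected step:
    0.0 for an episode that really terminated, or an estimate (here, the
    global reward-to-go mean) for an episode cut off by a time limit or by
    the end of a collect round.
    """
    returns = [0.0] * len(rewards)
    running = bootstrap
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


class RunningMeanStd:
    """Global (running) mean and std of every return seen so far."""

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations (Welford's method)

    def update(self, values: Sequence[float]) -> None:
        for x in values:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

    @property
    def std(self) -> float:
        # Population std; fall back to 1.0 before any data has been seen.
        return math.sqrt(self.m2 / self.count) if self.count else 1.0


def normalize(returns: Sequence[float], stats: RunningMeanStd,
              eps: float = 1e-8) -> List[float]:
    """Subtract the global mean (a naive constant baseline), divide by std."""
    return [(g - stats.mean) / (stats.std + eps) for g in returns]


if __name__ == "__main__":
    stats = RunningMeanStd()
    # Episode 1 terminated on its own, so bootstrap with 0.
    ep1 = reward_to_go([1.0, 1.0, 1.0], gamma=0.5)
    stats.update(ep1)
    # Episode 2 was truncated, so bootstrap with the global mean instead of 0.
    ep2 = reward_to_go([1.0, 1.0], gamma=0.5, bootstrap=stats.mean)
    print(normalize(ep2, stats))
```

Because the same global mean is subtracted from every target, it shifts all returns uniformly, which is why hint 2 calls it a naive baseline rather than a learned, state-dependent one.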
diff --git a/examples/mujoco/benchmark/Ant-v3/figure.png b/examples/mujoco/benchmark/Ant-v3/offpolicy.png
similarity index 100%
rename from examples/mujoco/benchmark/Ant-v3/figure.png
rename to examples/mujoco/benchmark/Ant-v3/offpolicy.png
diff --git a/examples/mujoco/benchmark/Ant-v3/reinforce/figure.png b/examples/mujoco/benchmark/Ant-v3/reinforce/figure.png
new file mode 100644
index 000000000..f747258f8
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/reinforce/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/offpolicy.png
similarity index 100%
rename from examples/mujoco/benchmark/HalfCheetah-v3/figure.png
rename to examples/mujoco/benchmark/HalfCheetah-v3/offpolicy.png
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/reinforce/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/reinforce/figure.png
new file mode 100644
index 000000000..08554a6a5
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/reinforce/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/figure.png b/examples/mujoco/benchmark/Hopper-v3/offpolicy.png
similarity index 100%
rename from examples/mujoco/benchmark/Hopper-v3/figure.png
rename to examples/mujoco/benchmark/Hopper-v3/offpolicy.png
diff --git a/examples/mujoco/benchmark/Hopper-v3/reinforce/figure.png b/examples/mujoco/benchmark/Hopper-v3/reinforce/figure.png
new file mode 100644
index 000000000..cf5157b51
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/reinforce/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/figure.png b/examples/mujoco/benchmark/Humanoid-v3/offpolicy.png
similarity index 100%
rename from examples/mujoco/benchmark/Humanoid-v3/figure.png
rename to examples/mujoco/benchmark/Humanoid-v3/offpolicy.png
diff --git a/examples/mujoco/benchmark/Humanoid-v3/reinforce/figure.png b/examples/mujoco/benchmark/Humanoid-v3/reinforce/figure.png
new file mode 100644
index 000000000..96600188e
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/reinforce/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/offpolicy.png
similarity index 100%
rename from examples/mujoco/benchmark/InvertedDoublePendulum-v2/figure.png
rename to examples/mujoco/benchmark/InvertedDoublePendulum-v2/offpolicy.png
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/reinforce/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/reinforce/figure.png
new file mode 100644
index 000000000..7c9a964e2
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/reinforce/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/offpolicy.png
similarity index 100%
rename from examples/mujoco/benchmark/InvertedPendulum-v2/figure.png
rename to examples/mujoco/benchmark/InvertedPendulum-v2/offpolicy.png
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/reinforce/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/reinforce/figure.png
new file mode 100644
index 000000000..e5c1e0a86
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/reinforce/figure.png differ
diff --git a/examples/mujoco/benchmark/README.md b/examples/mujoco/benchmark/README.md
new file mode 100644
index 000000000..66fb65187
--- /dev/null
+++ b/examples/mujoco/benchmark/README.md
@@ -0,0 +1,37 @@
+# Benchmark Result
+
+## Ant-v3
+
+![](Ant-v3/offpolicy.png)
+
+## HalfCheetah-v3
+
+![](HalfCheetah-v3/offpolicy.png)
+
+## Hopper-v3
+
+![](Hopper-v3/offpolicy.png)
+
+## Walker2d-v3
+
+![](Walker2d-v3/offpolicy.png)
+
+## Swimmer-v3
+
+![](Swimmer-v3/offpolicy.png)
+
+## Humanoid-v3
+
+![](Humanoid-v3/offpolicy.png)
+
+## Reacher-v2
+
+![](Reacher-v2/offpolicy.png)
+
+## InvertedPendulum-v2
+
+![](InvertedPendulum-v2/offpolicy.png)
+
+## InvertedDoublePendulum-v2
+
+![](InvertedDoublePendulum-v2/offpolicy.png)
diff --git a/examples/mujoco/benchmark/Reacher-v2/figure.png b/examples/mujoco/benchmark/Reacher-v2/offpolicy.png
similarity index 100%
rename from examples/mujoco/benchmark/Reacher-v2/figure.png
rename to examples/mujoco/benchmark/Reacher-v2/offpolicy.png
diff --git a/examples/mujoco/benchmark/Reacher-v2/reinforce/figure.png b/examples/mujoco/benchmark/Reacher-v2/reinforce/figure.png
new file mode 100644
index 000000000..d87bcd8b9
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/reinforce/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/figure.png b/examples/mujoco/benchmark/Swimmer-v3/offpolicy.png
similarity index 100%
rename from examples/mujoco/benchmark/Swimmer-v3/figure.png
rename to examples/mujoco/benchmark/Swimmer-v3/offpolicy.png
diff --git a/examples/mujoco/benchmark/Swimmer-v3/reinforce/figure.png b/examples/mujoco/benchmark/Swimmer-v3/reinforce/figure.png
new file mode 100644
index 000000000..9ca015118
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/reinforce/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/figure.png b/examples/mujoco/benchmark/Walker2d-v3/offpolicy.png
similarity index 100%
rename from examples/mujoco/benchmark/Walker2d-v3/figure.png
rename to examples/mujoco/benchmark/Walker2d-v3/offpolicy.png
diff --git a/examples/mujoco/benchmark/Walker2d-v3/reinforce/figure.png b/examples/mujoco/benchmark/Walker2d-v3/reinforce/figure.png
new file mode 100644
index 000000000..84c2c0c76
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/reinforce/figure.png differ
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
index 9f5af3e4b..fae9b00f9 100755
--- a/examples/mujoco/mujoco_ddpg.py
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -103,14 +103,16 @@ def test_ddpg(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.start_timesteps, random=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'ddpg', 'seed_' + str(
-        args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
+    log_path = os.path.join(args.logdir, args.task, 'ddpg', 'seed_' + str(args.seed)
+                            + '_' + datetime.datetime.now().strftime('%m%d_%H%M%S')
+                            + '-' + args.task.replace('-', '_') + '_ddpg')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
+    # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py
new file mode 100755
index 000000000..f7fc28c85
--- /dev/null
+++ b/examples/mujoco/mujoco_reinforce.py
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+
+import os
+import gym
+import torch
+import datetime
+import argparse
+import numpy as np
+from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
+from torch.utils.tensorboard import SummaryWriter
+from torch.distributions import Independent, Normal
+
+from tianshou.policy import PGPolicy
+from tianshou.utils import BasicLogger
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.common import Net
+from tianshou.trainer import onpolicy_trainer
+from tianshou.utils.net.continuous import ActorProb
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='HalfCheetah-v3')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--buffer-size', type=int, default=4096)
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=30000)
+    parser.add_argument('--step-per-collect', type=int, default=2048)
+    parser.add_argument('--repeat-per-collect', type=int, default=1)
+    # batch-size >> step-per-collect means computing all data in one single forward.
+    parser.add_argument('--batch-size', type=int, default=99999)
+    parser.add_argument('--training-num', type=int, default=64)
+    parser.add_argument('--test-num', type=int, default=10)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    parser.add_argument('--resume-path', type=str, default=None)
+    # reinforce special
+    parser.add_argument('--rew-norm', type=int, default=True)
+    # "clip" option also works well.
+    parser.add_argument('--action-bound-method', type=str, default="tanh")
+    parser.add_argument('--lr-decay', type=int, default=True)
+    return parser.parse_args()
+
+
+def test_reinforce(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    args.max_action = env.action_space.high[0]
+    print("Observations shape:", args.state_shape)
+    print("Actions shape:", args.action_shape)
+    print("Action range:", np.min(env.action_space.low),
+          np.max(env.action_space.high))
+    # train_envs = gym.make(args.task)
+    train_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)],
+        norm_obs=True)
+    # test_envs = gym.make(args.task)
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)],
+        norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
+
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+                activation=nn.Tanh, device=args.device)
+    actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
+                      unbounded=True, device=args.device).to(args.device)
+    torch.nn.init.constant_(actor.sigma_param, -0.5)
+    for m in actor.modules():
+        if isinstance(m, torch.nn.Linear):
+            # orthogonal initialization
+            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+            torch.nn.init.zeros_(m.bias)
+    # scale down the last policy layer so that initial actions have (close to)
+    # zero mean and std; this helps boost performance,
+    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
+    for m in actor.mu.modules():
+        if isinstance(m, torch.nn.Linear):
+            torch.nn.init.zeros_(m.bias)
+            m.weight.data.copy_(0.01 * m.weight.data)
+
+    optim = torch.optim.Adam(actor.parameters(), lr=args.lr)
+    lr_scheduler = None
+    if args.lr_decay:
+        # decay learning rate to 0 linearly
+        max_update_num = np.ceil(
+            args.step_per_epoch / args.step_per_collect) * args.epoch
+
+        lr_scheduler = LambdaLR(
+            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+
+    def dist(*logits):
+        return Independent(Normal(*logits), 1)
+
+    policy = PGPolicy(actor, optim, dist, discount_factor=args.gamma,
+                      reward_normalization=args.rew_norm, action_scaling=True,
+                      action_bound_method=args.action_bound_method,
+                      lr_scheduler=lr_scheduler, action_space=env.action_space)
+
+    # collector
+    if args.training_num > 1:
+        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+    else:
+        buffer = ReplayBuffer(args.buffer_size)
+    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+    test_collector = Collector(policy, test_envs)
+    # log
+    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
+    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_reinforce'
+    log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = BasicLogger(writer, update_interval=10)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    # trainer
+    result = onpolicy_trainer(
+        policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
+        args.repeat_per_collect, args.test_num, args.batch_size,
+        step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
+        test_in_train=False)
+
+    # Let's watch its performance!
+    policy.eval()
+    test_envs.seed(args.seed)
+    test_collector.reset()
+    result = test_collector.collect(n_episode=args.test_num, render=args.render)
+    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
+
+if __name__ == '__main__':
+    test_reinforce()
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index cf57318a8..64685ad45 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -115,14 +115,16 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.start_timesteps, random=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
-        args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
+    log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(args.seed)
+                            + '_' + datetime.datetime.now().strftime('%m%d_%H%M%S')
+                            + '-' + args.task.replace('-', '_') + '_sac')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
+    # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
index fd7c4eae6..1b33076cb 100755
--- a/examples/mujoco/mujoco_td3.py
+++ b/examples/mujoco/mujoco_td3.py
@@ -117,14 +117,16 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.start_timesteps, random=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'td3', 'seed_' + str(
-        args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
+    log_path = os.path.join(args.logdir, args.task, 'td3', 'seed_' + str(args.seed)
+                            + '_' + datetime.datetime.now().strftime('%m%d_%H%M%S')
+                            + '-' + args.task.replace('-', '_') + '_td3')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = BasicLogger(writer)
 
     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
+    # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
diff --git a/examples/mujoco/run_experiments.sh b/examples/mujoco/run_experiments.sh
index 4de3263f5..3f0f8a9c5 100755
--- a/examples/mujoco/run_experiments.sh
+++ b/examples/mujoco/run_experiments.sh
@@ -4,7 +4,8 @@ LOGDIR="results"
 TASK=$1
 
 echo "Experiments started."
-for seed in $(seq 1 10)
+for seed in $(seq 0 9)
 do
     python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1
 done
+echo "Experiments ended."
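Hint 4 and the `--lr-decay` branch of `mujoco_reinforce.py` decay the learning rate linearly to 0 over `max_update_num = ceil(step_per_epoch / step_per_collect) * epoch` updates. Below is a framework-free sketch of the multiplier that the `LambdaLR` lambda returns. It assumes the scheduler is stepped once per collect/update round, which is what the `max_update_num` formula implies. The function name `linear_decay_factor` is hypothetical and its defaults simply copy the script's argparse defaults.

```python
import math


def linear_decay_factor(update_idx: int, step_per_epoch: int = 30000,
                        step_per_collect: int = 2048, epoch: int = 100) -> float:
    """Multiplier on the initial lr after `update_idx` scheduler steps.

    Same formula as the lambda given to LambdaLR in mujoco_reinforce.py
    (hypothetical helper; assumes one scheduler step per collect round).
    """
    # Collect/update rounds per epoch (rounded up) times the number of epochs.
    max_update_num = math.ceil(step_per_epoch / step_per_collect) * epoch
    return 1 - update_idx / max_update_num


if __name__ == "__main__":
    initial_lr = 1e-3  # the script's default --lr
    for idx in (0, 750, 1500):
        print(idx, initial_lr * linear_decay_factor(idx))
```

With the defaults there are 15 rounds per epoch and 1500 rounds in total, so the learning rate reaches half its initial value after 750 updates and 0 at the end of training. Following hint 4, passing `--lr-decay 0` keeps the learning rate constant, which is the recommended setting for a small initial learning rate or a short run.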